diff --git a/dataset/Test Generation/ComplexCodeEval-Java/3/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Java/3/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..2da802473f6bac6037a4dc23066321f922d7070a --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/3/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n 
AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n 
recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": " @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, 
meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n 
Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + 
"reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, 
meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 0, + "token_diff": 0 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/3/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/3/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..e96ee55b2310b6ede55e8dce3b918ae83f260049 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/3/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, 
subset_0(22~57),subset_1(59~85),subset_2(94~128) +StarCoder2-15b,32.05,25.09,17.77 +CodeLlama-7b,33.09,25.34,23.49 +CodeLlama-13b,30.49,24.62,21.73 +CodeLlama-34b,31.61,24.30,23.39 +DeepSeek-Coder-1.3b,30.93,25.69,20.56 +DeepSeek-Coder-6.7b,20.22,23.69,22.39 +DeepSeek-Coder-33b,34.15,27.25,20.25 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/3/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/3/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..f6b21891f1512ea2964e5bc49d59ea02b0d1ea55 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/3/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~377),subset_1(424~664),subset_2(681~952) +StarCoder2-15b,32.51,22.26,19.93 +CodeLlama-7b,33.77,22.98,25.15 +CodeLlama-13b,30.70,23.83,21.73 +CodeLlama-34b,31.85,23.58,24.12 +DeepSeek-Coder-1.3b,31.14,24.98,19.11 +DeepSeek-Coder-6.7b,20.65,21.81,23.54 +DeepSeek-Coder-33b,34.23,25.40,22.73 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/3/EI/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/3/EI/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/3/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = 
line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/3/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Java/3/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..c0392cceefff7b0a2bbdeccaa59ceff949f78538 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/3/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/3/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/3/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..7b438acb03bccfd96bda4543352b27c6a88dc02f --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/3/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, 
subset_0(22~37),subset_1(37~57),subset_2(57~128) +StarCoder2-15b,31.93,32.36,22.45 +CodeLlama-7b,34.19,32.50,24.35 +CodeLlama-13b,30.66,30.53,23.52 +CodeLlama-34b,31.38,32.19,23.85 +DeepSeek-Coder-1.3b,29.87,32.10,23.90 +DeepSeek-Coder-6.7b,20.28,20.43,22.83 +DeepSeek-Coder-33b,34.69,33.71,24.80 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/3/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/3/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..b384140fd7c2eb5ef1ae027aced9e5f506fb107f --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/3/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~279),subset_1(279~377),subset_2(424~952) +StarCoder2-15b,30.79,34.29,21.70 +CodeLlama-7b,33.23,34.33,23.51 +CodeLlama-13b,29.66,31.77,23.32 +CodeLlama-34b,29.98,33.77,23.71 +DeepSeek-Coder-1.3b,28.72,33.62,23.56 +DeepSeek-Coder-6.7b,20.00,21.32,22.23 +DeepSeek-Coder-33b,33.30,35.19,24.75 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/3/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/3/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/3/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = 
line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/4/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Java/4/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..73c7e833444012d6b3d6d57357ecb211ead4c9fa --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/4/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 0, + "token_diff": 1 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/4/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/4/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..d20b8394acd7abcef00fb3b9f80586966f739596 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/4/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, 
"""Print the subset labels (value ranges per diff bucket) of a benchmark JSON file.

The input file is a JSON array of records. Each record carries a ``line`` and
a ``token`` size together with the bucket it was assigned to (``line_diff``
and ``token_diff``). For each of the two metrics the script prints one line
listing every bucket with the smallest and largest size it contains, e.g.
``Models: subset_0(22~47),subset_1(51~74)``.
"""

import json
import os
from collections import defaultdict

# File names probed by the command-line entry point, in priority order.
DEFAULT_INPUT_NAMES = ("EI.json", "QS.json")


def _subset_header(entries, diff_key, size_key):
    """Return the comma-joined ``subset_<bucket>(<min>~<max>)`` labels for one metric.

    Args:
        entries: Iterable of record dicts loaded from the JSON file.
        diff_key: Key holding the bucket id of a record (e.g. ``"line_diff"``).
        size_key: Key holding the measured size of a record (e.g. ``"line"``).

    Returns:
        One label per distinct bucket, ordered by bucket id; an empty string
        when ``entries`` is empty.

    Raises:
        KeyError: If a record lacks ``diff_key`` or ``size_key``.
    """
    sizes_by_bucket = defaultdict(list)
    for entry in entries:
        sizes_by_bucket[entry[diff_key]].append(entry[size_key])
    # Bucket ids are unique dict keys, so sorting the items orders by bucket
    # id alone and never has to compare the size lists.
    return ",".join(
        f"subset_{bucket}({min(sizes)}~{max(sizes)})"
        for bucket, sizes in sorted(sizes_by_bucket.items())
    )


def analyze_json_file(file_path):
    """Print the line-count and token-count subset labels for ``file_path``.

    Args:
        file_path: Path to a UTF-8 JSON file containing a list of records with
            the keys ``line``, ``token``, ``line_diff`` and ``token_diff``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a record lacks one of the four required keys.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Build both headers before printing anything so that a malformed record
    # fails the whole run instead of leaving half of the output behind.
    line_header = _subset_header(data, "line_diff", "line")
    token_header = _subset_header(data, "token_diff", "token")

    # NOTE(review): the line_counts_*.csv / token_counts_*.csv headers stored
    # with this dataset begin with "Models, " rather than "Models: "; the
    # prefix is kept unchanged here -- confirm which form is intended.
    print(f"Models: {line_header}")
    print(f"Models: {token_header}")


def _find_default_input():
    """Return the path of the first existing default input file.

    The current working directory is searched first, then the directory that
    contains this script, so the script also works when it is launched from
    another directory. If nothing is found, the last default name is returned
    and opening it reports the missing file.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    # Joining with "" yields the bare file name, i.e. a path relative to the
    # current working directory.
    for directory in ("", script_dir):
        for name in DEFAULT_INPUT_NAMES:
            candidate = os.path.join(directory, name)
            if os.path.exists(candidate):
                return candidate
    return DEFAULT_INPUT_NAMES[-1]


if __name__ == "__main__":
    # Analyze EI.json when it is present, otherwise QS.json.
    analyze_json_file(_find_default_input())
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 2, + "token_diff": 2 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/4/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/4/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..b2e4eabe47222b5ec58a103f6e11dbe059ea1801 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/4/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, 
subset_0(22~32),subset_1(32~44),subset_2(44~74),subset_3(74~128) +StarCoder2-15b,32.69,31.06,30.98,21.04 +CodeLlama-7b,35.41,33.18,31.16,21.78 +CodeLlama-13b,30.73,30.46,30.14,21.72 +CodeLlama-34b,33.20,30.11,30.81,22.53 +DeepSeek-Coder-1.3b,28.44,31.39,32.32,22.39 +DeepSeek-Coder-6.7b,20.08,22.14,19.90,22.57 +DeepSeek-Coder-33b,32.05,36.05,34.47,21.83 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/4/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/4/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..a69755970e5d25e1dcc94a65c722eb6e6f6bfbb4 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/4/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~234),subset_1(241~349),subset_2(354~533),subset_3(533~952) +StarCoder2-15b,33.94,27.49,33.30,21.04 +CodeLlama-7b,34.71,31.43,33.62,21.78 +CodeLlama-13b,30.54,28.56,32.23,21.72 +CodeLlama-34b,32.93,27.21,33.99,22.53 +DeepSeek-Coder-1.3b,27.61,30.73,33.81,22.39 +DeepSeek-Coder-6.7b,19.56,23.28,19.28,22.57 +DeepSeek-Coder-33b,32.95,31.98,37.64,21.83 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/4/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/4/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/4/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", 
end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/5/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Java/5/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..40f5dc33f6a3ec050d629b7ecf257193df6e39a8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/5/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/5/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/5/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..709401efc33329dae00c8dca66d72148190d2d79 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/5/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, 
subset_0(22~43),subset_1(44~63),subset_2(65~85),subset_3(94~99),subset_4(109~128) +StarCoder2-15b,31.28,33.86,24.02,15.04,20.50 +CodeLlama-7b,33.31,32.86,24.18,23.35,23.62 +CodeLlama-13b,30.06,31.34,24.00,20.23,23.23 +CodeLlama-34b,30.25,34.21,23.72,23.56,23.22 +DeepSeek-Coder-1.3b,28.75,35.56,24.86,18.36,22.76 +DeepSeek-Coder-6.7b,21.24,18.71,23.31,20.24,24.54 +DeepSeek-Coder-33b,33.10,36.17,26.63,18.64,21.86 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/5/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/5/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..294579d57018f2e07397709735e328c7fc1e136c --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/5/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~209),subset_1(211~300),subset_2(300~365),subset_3(366~544),subset_4(602~952) +StarCoder2-15b,30.01,36.39,21.63,18.67,22.42 +CodeLlama-7b,33.01,34.76,22.99,25.34,23.90 +CodeLlama-13b,29.56,32.30,23.94,21.82,21.52 +CodeLlama-34b,29.84,35.10,22.95,22.07,26.06 +DeepSeek-Coder-1.3b,28.67,35.08,24.86,21.22,18.65 +DeepSeek-Coder-6.7b,20.97,20.27,21.39,25.76,23.39 +DeepSeek-Coder-33b,32.12,37.38,25.53,26.03,20.64 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/5/EI/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/5/EI/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/5/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + 
line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/5/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Java/5/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..e9b016f5ae744a6e02a232537f3ec75c4fb6e00f --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/5/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 3, + "token_diff": 1 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 2, + "token_diff": 3 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 1 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/5/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/5/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..92cc54fee726d0e28ea02c156383ca7e8f45657f --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/5/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ 
+Models,subset_0(22~29),subset_1(30~39),subset_2(41~47),subset_3(47~85),subset_4(85~128) +StarCoder2-15b,33.66,29.03,36.27,26.50,19.25 +CodeLlama-7b,35.78,29.99,36.10,27.81,22.25 +CodeLlama-13b,30.86,27.90,34.93,26.51,21.12 +CodeLlama-34b,33.25,27.24,35.53,27.23,22.56 +DeepSeek-Coder-1.3b,28.50,29.12,34.01,30.23,21.31 +DeepSeek-Coder-6.7b,20.00,21.47,19.97,22.38,22.03 +DeepSeek-Coder-33b,32.10,33.62,39.30,29.49,20.99 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/5/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/5/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..730c8998a7806894fec2c1019ad9b74fde257819 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/5/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~209),subset_1(211~300),subset_2(300~365),subset_3(366~544),subset_4(602~952) +StarCoder2-15b,33.66,27.58,36.39,24.99,22.10 +CodeLlama-7b,35.78,31.14,34.89,28.17,21.95 +CodeLlama-13b,30.86,28.94,33.44,26.35,21.72 +CodeLlama-34b,33.25,26.90,37.47,25.52,22.66 +DeepSeek-Coder-1.3b,28.50,29.56,34.82,28.26,22.04 +DeepSeek-Coder-6.7b,20.00,21.50,22.60,18.84,22.92 +DeepSeek-Coder-33b,32.10,33.27,38.51,28.64,22.98 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/5/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/5/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/5/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + 
line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/6/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Java/6/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..94df7d0d42d5850d3e1df6d8ed604a914f362279 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/6/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/6/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/6/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..3d14eaa1c8c704e48e971d84ebcebc85831d903e --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/6/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, 
import json
import os
from collections import defaultdict


def _format_subsets(values_by_subset):
    """Return one ``subset_<key>(<min>~<max>)`` label per key, in ascending key order."""
    return [
        f"subset_{key}({min(values_by_subset[key])}~{max(values_by_subset[key])})"
        for key in sorted(values_by_subset)
    ]


def analyze_json_file(file_path):
    """Print the line-count and token-count range of each subset in a JSON file.

    The file must contain a JSON array of objects that each carry the keys
    ``line``, ``token``, ``line_diff`` and ``token_diff``.  Entries are grouped
    by ``line_diff`` (collecting ``line``) and, independently, by
    ``token_diff`` (collecting ``token``).  For each grouping one row of the
    form ``Models: subset_<key>(<min>~<max>),...`` is written to stdout: the
    line-based row first, then the token-based row.

    Args:
        file_path: Path of the UTF-8 encoded JSON file to analyse.

    Returns:
        None. The result is only printed.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If an entry lacks one of the four required keys.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Subset key -> every value observed for that subset.
    line_diff_stats = defaultdict(list)
    token_diff_stats = defaultdict(list)

    for entry in data:
        line_diff_stats[entry["line_diff"]].append(entry["line"])
        token_diff_stats[entry["token_diff"]].append(entry["token"])

    # Both rows share one format; line statistics are printed first.
    for stats in (line_diff_stats, token_diff_stats):
        print("Models: " + ",".join(_format_subsets(stats)))


# Usage example
if __name__ == "__main__":
    # Prefer EI.json in the current directory; otherwise fall back to QS.json.
    file_path = "EI.json" if os.path.exists("EI.json") else "QS.json"
    analyze_json_file(file_path)
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 3, + "token_diff": 3 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 1 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/6/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/6/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..1d1024071d37b4fac2451862991ddc5c288deb0a --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/6/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, 
subset_0(22~28),subset_1(29~37),subset_2(37~44),subset_3(44~57),subset_4(59~85),subset_5(85~128) +StarCoder2-15b,33.42,30.44,32.66,31.66,26.63,18.06 +CodeLlama-7b,35.78,32.60,33.45,30.50,27.62,21.68 +CodeLlama-13b,30.27,31.05,30.77,29.86,27.03,20.04 +CodeLlama-34b,32.97,29.78,32.43,31.26,25.44,22.48 +DeepSeek-Coder-1.3b,28.15,31.59,30.67,33.30,27.95,19.59 +DeepSeek-Coder-6.7b,18.13,22.43,23.05,17.26,24.35,22.05 +DeepSeek-Coder-33b,31.33,38.05,33.66,33.55,29.14,20.10 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/6/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/6/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..3cc1d704f3acf50933e671de067688f6cbb103a1 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/6/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~205),subset_1(205~279),subset_2(279~354),subset_3(358~424),subset_4(486~606),subset_5(606~952) +StarCoder2-15b,31.09,30.49,29.57,38.64,21.92,21.08 +CodeLlama-7b,32.92,33.55,31.68,36.57,25.33,21.44 +CodeLlama-13b,29.05,30.27,29.62,33.49,25.47,21.11 +CodeLlama-34b,31.60,28.37,29.74,37.70,23.95,22.95 +DeepSeek-Coder-1.3b,26.90,30.55,29.58,37.45,25.08,21.63 +DeepSeek-Coder-6.7b,20.31,19.69,24.35,18.45,21.75,22.60 +DeepSeek-Coder-33b,31.32,35.27,30.36,39.57,26.67,22.65 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/6/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/6/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/6/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = 
entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/7/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Java/7/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..8763191369f2a9eee1f11be91f9d37a0992c3369 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/7/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 4, + "token_diff": 6 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/7/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/7/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..3319de71b94a2739f5e91a13a2386b534e20e0b9 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/7/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, 
subset_0(139~205),subset_1(205~279),subset_2(279~354),subset_3(358~424),subset_4(486~606),subset_5(606~952) +StarCoder2-15b,31.29,33.67,27.94,24.39,20.96,13.83,24.52 +CodeLlama-7b,33.59,33.65,25.37,34.67,20.62,28.42,20.99 +CodeLlama-13b,29.62,31.82,25.93,30.93,21.50,20.41,23.66 +CodeLlama-34b,30.52,33.61,24.26,29.43,22.09,20.93,25.67 +DeepSeek-Coder-1.3b,28.96,33.58,27.15,28.66,22.93,16.45,23.89 +DeepSeek-Coder-6.7b,20.37,20.67,19.51,25.49,21.52,24.54,24.81 +DeepSeek-Coder-33b,33.64,34.81,32.11,34.68,21.33,28.31,18.29 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/7/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/7/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..03939df5a71362fc9ae7dfa35c1c0333c1a3ce5a --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/7/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~255),subset_1(259~367),subset_2(373~486),subset_3(490~602),subset_4(606~681),subset_5(793~793),subset_6(935~952) +StarCoder2-15b,31.18,33.11,30.92,20.98,21.56,16.57,23.59 +CodeLlama-7b,33.77,33.36,29.04,28.25,20.22,27.45,23.19 +CodeLlama-13b,30.68,30.95,26.74,25.39,21.88,16.48,22.53 +CodeLlama-34b,30.98,33.69,21.87,25.83,20.56,26.15,26.04 +DeepSeek-Coder-1.3b,29.52,32.18,31.15,24.27,24.27,8.33,20.71 +DeepSeek-Coder-6.7b,19.88,22.09,13.21,23.21,22.51,25.63,22.94 +DeepSeek-Coder-33b,34.50,33.72,34.46,26.08,23.97,26.97,19.38 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/7/EI/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/7/EI/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/7/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + 
line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/7/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Java/7/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..c194d3a69ae8fecf141dde666bcd0ad607e70b79 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/7/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 2, + "token_diff": 4 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 3, + "token_diff": 4 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 1 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/7/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/7/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..0ae45f4bc2a89c42bf74d65e6e764a059247f1cb --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/7/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, 
subset_0(22~26),subset_1(27~36),subset_2(36~42),subset_3(42~45),subset_4(46~70),subset_5(70~85),subset_6(85~128) +StarCoder2-15b,34.01,28.36,31.34,37.87,28.95,23.71,18.04 +CodeLlama-7b,36.81,33.41,30.02,35.05,29.66,23.48,23.57 +CodeLlama-13b,30.91,32.25,27.06,34.66,26.05,25.24,21.20 +CodeLlama-34b,32.95,30.63,26.76,40.14,26.25,23.39,23.65 +DeepSeek-Coder-1.3b,29.10,31.76,26.04,35.73,30.30,27.63,19.63 +DeepSeek-Coder-6.7b,18.24,21.67,22.89,20.17,19.95,22.88,22.57 +DeepSeek-Coder-33b,30.65,38.49,30.31,40.46,29.93,27.16,20.20 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/7/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/7/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..b6eeee6e6bd52ab8fc0db7536a44a191f4d1401c --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/7/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~183),subset_1(187~259),subset_2(273~312),subset_3(330~358),subset_4(358~508),subset_5(530~606),subset_6(606~952) +StarCoder2-15b,30.09,32.14,28.25,38.02,31.59,21.33,20.87 +CodeLlama-7b,31.89,35.35,30.66,36.52,30.35,25.93,21.53 +CodeLlama-13b,27.75,32.32,27.99,35.03,28.56,25.61,20.33 +CodeLlama-34b,30.76,30.96,27.46,38.94,28.51,24.47,22.80 +DeepSeek-Coder-1.3b,27.06,31.64,27.11,38.68,29.50,25.86,20.50 +DeepSeek-Coder-6.7b,19.38,20.49,24.06,20.51,18.16,23.44,22.34 +DeepSeek-Coder-33b,30.38,38.09,26.98,41.58,32.81,25.19,22.22 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/7/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/7/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/7/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + 
line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/8/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Java/8/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..87a4dd39d36b227f5a8a617ca5c266dc9bd85e25 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/8/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 5, + "token_diff": 3 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 5, + "token_diff": 3 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 5, + "token_diff": 3 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 2 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 5, + "token_diff": 7 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 5, + "token_diff": 3 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/8/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Java/8/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..3d1ec685293ee9b5213b539dcfa383dd1be4abca --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/8/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, 
def _format_subsets(stats):
    """Build one ``subset_<diff>(<min>~<max>)`` label per bucket, ordered by bucket key.

    Args:
        stats: Mapping from a difficulty bucket key to the non-empty list of
            values (line counts or token counts) collected for that bucket.

    Returns:
        A list of label strings, one per key of ``stats`` in ascending key order.
    """
    return [
        f"subset_{diff}({min(stats[diff])}~{max(stats[diff])})"
        for diff in sorted(stats)
    ]


def analyze_json_file(file_path):
    """Print the line-count and token-count range of every difficulty bucket.

    The file at ``file_path`` must hold a JSON array of objects, each carrying
    the keys ``line``, ``token``, ``line_diff`` and ``token_diff``. Entries are
    grouped by ``line_diff`` (collecting ``line``) and by ``token_diff``
    (collecting ``token``). Two lines are then written to stdout, each starting
    with ``"Models: "`` followed by comma-joined ``subset_<diff>(<min>~<max>)``
    labels: first the line buckets, then the token buckets.

    Args:
        file_path: Path of the UTF-8 encoded JSON file to analyse.

    Raises:
        KeyError: If an entry lacks one of the four required keys.
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Load the whole dataset; it is a single JSON array.
    with open(file_path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    # Bucket key -> list of observed values for that bucket.
    line_diff_stats = defaultdict(list)
    token_diff_stats = defaultdict(list)

    # Collect the raw values per bucket.
    for entry in data:
        line_diff_stats[entry['line_diff']].append(entry['line'])
        token_diff_stats[entry['token_diff']].append(entry['token'])

    # Report the line buckets first, then the token buckets, on separate lines.
    for stats in (line_diff_stats, token_diff_stats):
        print("Models: ", end="")
        print(",".join(_format_subsets(stats)))


# Example usage
if __name__ == "__main__":
    # Analyse EI.json when it exists in the working directory, otherwise QS.json.
    file_path = "EI.json" if os.path.exists("EI.json") else "QS.json"
    analyze_json_file(file_path)
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTask", + "reference": 
" @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(TEST_S3_BUCKET);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTask() throws Exception {\n // Given\n 
RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 100L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockKafkaConsumer.consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong())).thenReturn(true);\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(true);\n Mockito.when(mockKafkaConsumer.close()).thenReturn(true);\n Mockito.when(mockChunkManager.stopAsync()).thenReturn(true);\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenReturn(true);\n Mockito.when(AstraKafkaConsumer.makeKafkaConfig(Mockito.any(), Mockito.anyInt())).thenReturn(new Properties());\n Mockito.when(RecoveryChunkManager.fromConfig(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any())).thenReturn(mockChunkManager);\n Mockito.when(adminClient.listConsumerGroupOffsets(Mockito.anyString())).thenReturn(new ConsumerGroupOffsets());\n Mockito.when(AstraConfig.getRecoveryConfig()).thenReturn(new RecoveryConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig()).thenReturn(new KafkaConfig());\n Mockito.when(AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic()).thenReturn(\"testTopic\");\n Mockito.when(AstraConfig.getIndexerConfig()).thenReturn(new IndexerConfig());\n Mockito.when(AstraConfig.getS3Config()).thenReturn(new S3Config());\n Mockito.when(meterRegistry.timer(Mockito.anyString())).thenReturn(Mockito.mock(Timer.class));\n Mockito.when(Timer.start(Mockito.any())).thenReturn(Mockito.mock(Timer.Sample.class));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n Assert.assertTrue(result);\n Mockito.verify(mockKafkaConsumer, 
Mockito.times(1)).prepConsumerForConsumption(Mockito.anyLong());\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).consumeMessagesBetweenOffsetsInParallel(Mockito.anyLong(), Mockito.anyLong(), Mockito.anyLong());\n Mockito.verify(mockChunkManager, Mockito.times(1)).waitForRollOvers();\n Mockito.verify(mockKafkaConsumer, Mockito.times(1)).close();\n Mockito.verify(mockChunkManager, Mockito.times(1)).stopAsync();\n Mockito.verify(mockChunkManager, Mockito.times(1)).awaitTerminated(Mockito.anyLong());\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 7, + "token_diff": 6 + }, + { + "id": 1, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + 
(partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - 
consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskFailure", + "reference": " @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n String fakeS3Bucket = \"fakeBucket\";\n AstraConfigs.AstraConfig astraCfg = makeAstraConfig(fakeS3Bucket);\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n // Start recovery service\n recoveryService = new RecoveryService(astraCfg, curatorFramework, meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Populate data in Kafka so we can recover from it.\n final Instant startTime = Instant.now();\n produceMessagesToKafka(kafkaServer.getBroker(), startTime, TEST_KAFKA_TOPIC_1, 0);\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\"testRecoveryTask\", \"0\", 30, 60, Instant.now().toEpochMilli());\n 
assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isFalse();\n\n assertThat(s3AsyncClient.listBuckets().get().buckets().size()).isEqualTo(1);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name()).isEqualTo(TEST_S3_BUCKET);\n assertThat(s3AsyncClient.listBuckets().get().buckets().get(0).name())\n .isNotEqualTo(fakeS3Bucket);\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, meterRegistry)).isEqualTo(31);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(1);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskFailure() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\"testTask\", 0, 0L, 10L, System.currentTimeMillis());\n AstraKafkaConsumer mockKafkaConsumer = Mockito.mock(AstraKafkaConsumer.class);\n RecoveryChunkManager mockChunkManager = Mockito.mock(RecoveryChunkManager.class);\n Mockito.when(mockKafkaConsumer.prepConsumerForConsumption(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.waitForRollOvers()).thenReturn(false);\n Mockito.when(mockKafkaConsumer.close()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.stopAsync()).thenThrow(new RuntimeException(\"Test exception\"));\n Mockito.when(mockChunkManager.awaitTerminated(Mockito.anyLong())).thenThrow(new RuntimeException(\"Test exception\"));\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertFalse(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 7, + "token_diff": 6 + }, + { + "id": 2, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public int migration(String srcTable, List uuids, String targetTable) {\n if (Objects.isNull(srcTable) || Objects.isNull(targetTable) || CollectionUtils.isEmpty(uuids)) {\n throw new SwValidationException(\n ValidSubject.EVALUATION, \"srcTable, targetTable and uuids must not be null\");\n }\n var filter = TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build();\n return store.migration(DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(filter)\n .build()\n );\n }\n```\n###test function signature:\n```java\n @Test\n public void testMigration()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMigration", + "reference": " @Test\n public void testMigration() {\n var srcTable = \"p/1/space/1/eval/summary\";\n var targetTable = \"p/1/eval/summary\";\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(null, List.of(\"uuid1\"), targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, null, targetTable));\n assertThrows(SwValidationException.class, () ->\n evaluationRepo.migration(srcTable, List.of(\"uuid1\"), null));\n\n evaluationRepo.migration(srcTable, List.of(\"uuid1\", \"uuid2\"), targetTable);\n\n verify(dataStore, times(1)).migration(\n DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(\n TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(List.of(\n TableQueryFilter.builder()\n 
.operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid1\")))\n .build(),\n TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, \"uuid2\")))\n .build()))\n .build())\n .build()\n );\n }", + "completions": [ + "```java\n @Test\n public void testMigration() {\n // Given\n String srcTable = \"srcTable\";\n List uuids = Arrays.asList(\"uuid1\", \"uuid2\", \"uuid3\");\n String targetTable = \"targetTable\";\n DataStoreMigrationRequest expectedRequest = DataStoreMigrationRequest.builder()\n .srcTableName(srcTable)\n .targetTableName(targetTable)\n .filter(TableQueryFilter.builder()\n .operator(Operator.OR)\n .operands(uuids.stream()\n .map(uuid -> TableQueryFilter.builder()\n .operator(Operator.EQUAL)\n .operands(List.of(\n new TableQueryFilter.Column(KeyColumn),\n new TableQueryFilter.Constant(\n ColumnType.STRING, uuid)\n ))\n .build())\n .collect(Collectors.toList())\n )\n .build()\n )\n .build();\n when(store.migration(any(DataStoreMigrationRequest.class))).thenReturn(1);\n\n // When\n int result = migrationService.migration(srcTable, uuids, targetTable);\n\n // Then\n assertEquals(1, result);\n verify(store).migration(expectedRequest);\n }\n```\n" + ], + "line": 37, + "token": 221, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 3, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n 
!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDatasMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDatasMapTest", + "reference": " @Test\n public void sendDatasMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"doris_test_db\";\n List> dataList = new ArrayList<>();\n for (int i = 0; i < 100; i++) {\n Map data = new HashMap<>();\n data.put(\"id\", i);\n data.put(\"name\", \"张三\");\n data.put(\"message\", \"股东风波的法规的规定翻跟斗广泛的给yryrtytr\");\n System.out.println(gson.toJson(data));\n dataList.add(data);\n\n }\n dorisStreamLoad.sendData(database, tableName, columnList, dataList);\n }", + "completions": [ + "```java\n @Test\n public void sendDataTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 4, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_serializable_key_and_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_serializable_key_and_value", + "reference": " @Test\n public void 
should_clone_map_of_serializable_key_and_value() {\n Map original = new HashMap<>();\n original.put(\"key\", new SerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_serializable_key_and_value() {\n Map originalMap = new HashMap<>();\n originalMap.put(\"one\", 1);\n originalMap.put(\"two\", 2);\n originalMap.put(\"three\", 3);\n\n Map clonedMap = clone(originalMap);\n\n assertNotSame(originalMap, clonedMap);\n assertEquals(originalMap, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 5, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void encode(String schemaName, long nextTransactionOffset) {\n if (schemaName == null) {\n schemaName = UNKNOWN;\n }\n this.schemaName = schemaName;\n this.nextTransactionOffset = nextTransactionOffset;\n final byte[] payloadBytes = payloadToBytes();\n final int payloadLength = payloadBytes.length;\n\n // set logEventHeader\n int eventSize = eventHeaderLengthVersionGt1 + payloadLength;\n setLogEventHeader(new LogEventHeader(drc_filter_log_event.getType(), 0, eventSize, nextTransactionOffset)\n );\n\n // set payload\n final ByteBuf payloadByteBuf = PooledByteBufAllocator.DEFAULT.directBuffer(payloadLength);\n payloadByteBuf.writeBytes(payloadBytes);\n payloadByteBuf.skipBytes(payloadLength);\n setPayloadBuf(payloadByteBuf);\n }\n```\n###test function signature:\n```java\n @Test\n public void encode()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "encode", + "reference": " @Test\n public void encode() {\n FilterLogEvent filterLogEvent = new FilterLogEvent();\n 
filterLogEvent.encode(\"drc1\", 101);\n Assert.assertEquals(\"drc1\", filterLogEvent.getSchemaName());\n Assert.assertEquals(101, filterLogEvent.getNextTransactionOffset());\n\n ByteBuf headerBuf = filterLogEvent.getLogEventHeader().getHeaderBuf();\n headerBuf.readerIndex(0);\n ByteBuf payloadBuf = filterLogEvent.getPayloadBuf();\n payloadBuf.readerIndex(0);\n CompositeByteBuf compositeByteBuf = PooledByteBufAllocator.DEFAULT.compositeDirectBuffer();\n compositeByteBuf.addComponents(true, headerBuf, payloadBuf);\n\n FilterLogEvent newFilterLogEvent = new FilterLogEvent();\n newFilterLogEvent.read(compositeByteBuf);\n compositeByteBuf.release(compositeByteBuf.refCnt());\n Assert.assertEquals(\"drc1\", newFilterLogEvent.getSchemaName());\n Assert.assertEquals(101, newFilterLogEvent.getNextTransactionOffset());\n }", + "completions": [ + "```java\n @Test\n public void encode() {\n // Given\n String schemaName = \"testSchema\";\n long nextTransactionOffset = 1234567890L;\n drc_filter_log_event event = new drc_filter_log_event();\n\n // When\n event.encode(schemaName, nextTransactionOffset);\n\n // Then\n assertEquals(schemaName, event.getSchemaName());\n assertEquals(nextTransactionOffset, event.getNextTransactionOffset());\n assertEquals(drc_filter_log_event.getType(), event.getLogEventHeader().getType());\n assertEquals(0, event.getLogEventHeader().getServerId());\n assertEquals(eventHeaderLengthVersionGt1 + event.getPayloadBuf().readableBytes(), event.getLogEventHeader().getEventLength());\n assertEquals(nextTransactionOffset, event.getLogEventHeader().getNextTransactionOffset());\n assertEquals(event.getPayloadBuf().readableBytes(), event.getPayloadBuf().readableBytes());\n }\n```\n" + ], + "line": 29, + "token": 183, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 6, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void put(String path, long position, ByteBuf data) {\n writeLock.lock();\n try {\n int dataLength = data.readableBytes();\n NavigableMap cache = path2cache.computeIfAbsent(path, k -> new TreeMap<>());\n Map.Entry pos2value = cache.floorEntry(position);\n long cacheStartPosition;\n Value value;\n if (pos2value == null || pos2value.getKey() + pos2value.getValue().dataLength < position) {\n cacheStartPosition = position;\n value = Value.EMPTY;\n } else {\n cacheStartPosition = pos2value.getKey();\n value = pos2value.getValue();\n }\n // ensure the capacity, if the capacity change then update the cache index\n int moreCapacity = (int) ((position + dataLength) - (cacheStartPosition + value.blocks.length * (long) blockSize));\n int newDataLength = (int) (position + dataLength - cacheStartPosition);\n if (moreCapacity > 0) {\n int[] blocks = ensureCapacity(cacheStartPosition, moreCapacity);\n if (blocks == null) {\n return;\n }\n int[] newBlocks = new int[value.blocks.length + blocks.length];\n System.arraycopy(value.blocks, 0, newBlocks, 0, value.blocks.length);\n System.arraycopy(blocks, 0, newBlocks, value.blocks.length, blocks.length);\n value = new Value(newBlocks, newDataLength);\n } else {\n value = new Value(value.blocks, newDataLength);\n }\n cache.put(cacheStartPosition, value);\n lru.put(new Key(path, cacheStartPosition), value);\n\n // write data to cache\n ByteBuffer cacheByteBuffer = this.cacheByteBuffer.duplicate();\n int positionDelta = (int) (position - cacheStartPosition);\n int written = 0;\n ByteBuffer[] nioBuffers = data.nioBuffers();\n int[] blocks = value.blocks;\n for (ByteBuffer nioBuffer : nioBuffers) {\n ByteBuf buf = Unpooled.wrappedBuffer(nioBuffer);\n while (buf.readableBytes() > 0) {\n int writePosition = positionDelta + written;\n int 
block = blocks[writePosition / blockSize];\n cacheByteBuffer.position(block * blockSize + writePosition % blockSize);\n int length = Math.min(buf.readableBytes(), blockSize - writePosition % blockSize);\n cacheByteBuffer.put(buf.slice(buf.readerIndex(), length).nioBuffer());\n buf.skipBytes(length);\n written += length;\n }\n }\n } finally {\n writeLock.unlock();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testMergePut() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMergePut", + "reference": " @Test\n public void testMergePut() throws IOException {\n FileCache fileCache = new FileCache(\"/tmp/file_cache_test\", 10 * 1024, 1024);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 1, 500));\n buf.addComponent(true, genBuf((byte) 2, 500));\n buf.addComponent(true, genBuf((byte) 3, 500));\n fileCache.put(\"test\", 3333L, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(1500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 4, 500));\n buf.addComponent(true, genBuf((byte) 5, 500));\n buf.addComponent(true, genBuf((byte) 6, 500));\n fileCache.put(\"test\", 3333L + 1000, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(2500, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n {\n CompositeByteBuf buf = Unpooled.compositeBuffer();\n buf.addComponent(true, genBuf((byte) 7, 500));\n fileCache.put(\"test\", 3333L + 1000 + 1500, buf);\n }\n assertEquals(1, fileCache.path2cache.get(\"test\").size());\n assertEquals(3000, fileCache.path2cache.get(\"test\").get(3333L).dataLength);\n\n assertTrue(verify(fileCache.get(\"test\", 3333L, 500).get(), (byte) 1));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 500, 500).get(), (byte) 2));\n 
assertTrue(verify(fileCache.get(\"test\", 3333L + 1000, 500).get(), (byte) 4));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 1500, 500).get(), (byte) 5));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2000, 500).get(), (byte) 6));\n assertTrue(verify(fileCache.get(\"test\", 3333L + 2500, 500).get(), (byte) 7));\n }", + "completions": [ + "```java\n @Test\n public void testMergePut() throws IOException {\n // Given\n String path = \"/test/path\";\n long position = 100;\n ByteBuf data = Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5});\n\n // When\n cache.put(path, position, data);\n\n // Then\n NavigableMap cacheMap = cache.path2cache.get(path);\n assertNotNull(cacheMap);\n Map.Entry entry = cacheMap.floorEntry(position);\n assertNotNull(entry);\n assertEquals(position, entry.getKey().longValue());\n assertEquals(5, entry.getValue().dataLength);\n assertEquals(1, entry.getValue().blocks[0]);\n assertEquals(2, entry.getValue().blocks[1]);\n assertEquals(3, entry.getValue().blocks[2]);\n assertEquals(4, entry.getValue().blocks[3]);\n assertEquals(5, entry.getValue().blocks[4]);\n }\n```\n" + ], + "line": 65, + "token": 508, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 7, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchById()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchById", + "reference": " @Test\n public void testSearchById() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult index =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(index.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchById() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 
1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 8, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", 
\"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n 
Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = new ArrayList<>();\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 9, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) 
objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_object", + "reference": " @Test\n public void should_clone_serializable_object() {\n Object original = new SerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_object() {\n // Given\n Serializable object = new Serializable() {\n private static final long serialVersionUID = 1L;\n // Implement necessary methods\n };\n\n // When\n Serializable clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertEquals(object, clonedObject);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 10, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() {\n\n\t\tvar text = \"Hello, how are you?\";\n\n\t\tEmbedding response = vertexAiPaLm2Api.embedText(text);\n\n\t\tassertThat(response).isNotNull();\n\t\tassertThat(response.value()).hasSize(768);\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(new EmbeddingResponse(expectedEmbedding));\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Optional maybeFetchErrorIntervalMs(short errorCode, int intervalMs) {\n if (errorCode == Errors.NONE.code())\n return Optional.empty();\n\n int pushIntervalMs;\n String reason;\n\n Errors error = Errors.forCode(errorCode);\n switch (error) {\n case INVALID_REQUEST:\n case INVALID_RECORD:\n case UNSUPPORTED_VERSION:\n pushIntervalMs = Integer.MAX_VALUE;\n reason = \"The broker response indicates the client sent an request that cannot be resolved\"\n + \" by re-trying, hence disable telemetry\";\n break;\n case UNKNOWN_SUBSCRIPTION_ID:\n case UNSUPPORTED_COMPRESSION_TYPE:\n pushIntervalMs = 0;\n reason = error.message();\n break;\n case TELEMETRY_TOO_LARGE:\n case THROTTLING_QUOTA_EXCEEDED:\n reason = error.message();\n pushIntervalMs = (intervalMs != -1) ? intervalMs : ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS;\n break;\n default:\n reason = \"Unwrapped error code\";\n log.error(\"Error code: {}. Unmapped error for telemetry, disable telemetry.\", errorCode);\n pushIntervalMs = Integer.MAX_VALUE;\n }\n\n log.debug(\"Error code: {}, reason: {}. 
Push interval update to {} ms.\", errorCode, reason, pushIntervalMs);\n return Optional.of(pushIntervalMs);\n }\n```\n###test function signature:\n```java\n @Test\n public void testMaybeFetchErrorIntervalMs()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testMaybeFetchErrorIntervalMs", + "reference": " @Test\n public void testMaybeFetchErrorIntervalMs() {\n assertEquals(Optional.empty(), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.NONE.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_REQUEST.code(), -1));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.INVALID_RECORD.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SUBSCRIPTION_ID.code(), -1));\n assertEquals(Optional.of(0), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNSUPPORTED_COMPRESSION_TYPE.code(), -1));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.TELEMETRY_TOO_LARGE.code(), 20000));\n assertEquals(Optional.of(ClientTelemetryReporter.DEFAULT_PUSH_INTERVAL_MS), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), -1));\n assertEquals(Optional.of(20000), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.THROTTLING_QUOTA_EXCEEDED.code(), 20000));\n assertEquals(Optional.of(Integer.MAX_VALUE), ClientTelemetryUtils.maybeFetchErrorIntervalMs(Errors.UNKNOWN_SERVER_ERROR.code(), -1));\n }", + "completions": [ + "```java\n @Test\n public void testMaybeFetchErrorIntervalMs() {\n // Test case 1: errorCode is Errors.NONE.code()\n short errorCode = Errors.NONE.code();\n int intervalMs = 1000;\n Optional result = maybeFetchErrorIntervalMs(errorCode, 
intervalMs);\n assertFalse(result.isPresent());\n\n // Test case 2: errorCode is Errors.INVALID_REQUEST.code()\n errorCode = Errors.INVALID_REQUEST.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 3: errorCode is Errors.UNKNOWN_SUBSCRIPTION_ID.code()\n errorCode = Errors.UNKNOWN_SUBSCRIPTION_ID.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 4: errorCode is Errors.TELEMETRY_TOO_LARGE.code()\n errorCode = Errors.TELEMETRY_TOO_LARGE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 5: errorCode is Errors.THROTTLING_QUOTA_EXCEEDED.code()\n errorCode = Errors.THROTTLING_QUOTA_EXCEEDED.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(intervalMs, result.get().intValue());\n\n // Test case 6: errorCode is Errors.UNSUPPORTED_COMPRESSION_TYPE.code()\n errorCode = Errors.UNSUPPORTED_COMPRESSION_TYPE.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(0, result.get().intValue());\n\n // Test case 7: errorCode is Errors.INVALID_RECORD.code()\n errorCode = Errors.INVALID_RECORD.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 8: errorCode is Errors.UNSUPPORTED_VERSION.code()\n errorCode = Errors.UNSUPPORTED_VERSION.code();\n result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n\n // Test case 9: errorCode is an unmapped error\n errorCode = (short) 9999;\n 
result = maybeFetchErrorIntervalMs(errorCode, intervalMs);\n assertTrue(result.isPresent());\n assertEquals(Integer.MAX_VALUE, result.get().intValue());\n }\n```\n" + ], + "line": 43, + "token": 290, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 12, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n 
}\n```\n###test function signature:\n```java\n @Test\n public void should_clone_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_non_serializable_object", + "reference": " @Test\n public void should_clone_non_serializable_object() {\n Object original = new NonSerializableObject(\"value\");\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_non_serializable_object() {\n // Given\n Map nonSerializableMap = new HashMap<>();\n nonSerializableMap.put(\"key1\", new Object());\n nonSerializableMap.put(\"key2\", new Object());\n\n // When\n Map clonedMap = objectCloner.clone(nonSerializableMap);\n\n // Then\n assertNotSame(nonSerializableMap, clonedMap);\n assertNotSame(nonSerializableMap.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(nonSerializableMap.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 13, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueOperationCompletesOperationWithCachedActionResult", + "reference": " @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ActionKey 
actionKey = DigestUtil.asActionKey(Digest.newBuilder().setHash(\"test\").build());\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"operation-with-cached-action-result\")\n .setActionDigest(actionKey.getDigest())\n .build();\n\n ActionResult actionResult = ActionResult.getDefaultInstance();\n\n when(mockBackplane.getActionResult(eq(actionKey))).thenReturn(actionResult);\n\n Poller poller = mock(Poller.class);\n\n instance.queue(executeEntry, poller, DEFAULT_TIMEOUT).get();\n\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(CACHE_CHECK));\n verify(mockBackplane, never()).putOperation(any(Operation.class), eq(QUEUED));\n verify(mockBackplane, times(1)).putOperation(any(Operation.class), eq(COMPLETED));\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueOperationCompletesOperationWithCachedActionResult() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n ActionResult actionResult = ActionResult.newBuilder().setExitCode(0).build();\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(\"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n when(cache.get(any(ActionKey.class))).thenReturn(actionResult);\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n future.get();\n\n verify(poller).pause();\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 14, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueActionFailsQueueEligibility", + "reference": " @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n Directory inputRoot = Directory.newBuilder().build();\n ByteString 
inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(false);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_INVALID)\n .setSubject(INVALID_PLATFORM)\n .setDescription(\n \"properties are not valid for queue eligibility: []. 
If you think your\"\n + \" queue should still accept these poperties without them being\"\n + \" specified in queue configuration, consider configuring the queue\"\n + \" with `allow_unmatched: True`\"))\n .build();\n ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueActionFailsQueueEligibility() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(1);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(ACTION_DIGEST));\n when(executeEntry.getStdoutStreamName()).thenReturn(STDOUT_STREAM_NAME);\n when(executeEntry.getStderrStreamName()).thenReturn(STDERR_STREAM_NAME);\n when(executeEntry.getOperationName()).thenReturn(OPERATION_NAME);\n when(executeEntry.getRequestMetadata()).thenReturn(REQUEST_METADATA);\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n when(poller.pause()).thenThrow(new RuntimeException(\"Queue eligibility failed\"));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n\n assertThrows(RuntimeException.class, future::get);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 15, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_serializable_nested_object", + "reference": " 
@Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new SerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new SerializableObject(\"name2\"),\n new SerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_serializable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key\", new SerializableObject());\n SerializableObject object = new SerializableObject(map);\n\n // When\n SerializableObject clonedObject = objectCloner.clone(object);\n\n // Then\n assertNotSame(object, clonedObject);\n assertNotSame(object.getMap(), clonedObject.getMap());\n assertNotSame(object.getMap().get(\"key\"), clonedObject.getMap().get(\"key\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 16, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testFullIndexSearchForSumAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForSumAgg", + "reference": " @Test\n public void testFullIndexSearchForSumAgg() {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new SumAggBuilder(\"test\", TEST_SOURCE_LONG_PROPERTY, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalSum internalSum =\n (InternalSum) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n // 1, 3, 4, 5\n assertThat(internalSum.getValue()).isEqualTo(13);\n }", + "completions": 
[ + "```java\n @Test\n public void testFullIndexSearchForSumAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new SumAggBuilder(\"testField\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertEquals(SumAggBuilder.class, result.getAggregation().getClass());\n assertEquals(\"testField\", ((SumAggBuilder) result.getAggregation()).getField());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 17, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void process(String base64AuthenticatorData, String signature, String clientDataJson, Fido2RegistrationData registration,\n Fido2AuthenticationData authenticationEntity) {\n log.debug(\"Registration: {}\", registration);\n\n AuthData authData = authenticatorDataParser.parseAssertionData(base64AuthenticatorData);\n commonVerifiers.verifyRpIdHash(authData, registration.getDomain());\n\n log.debug(\"User verification option: {}\", authenticationEntity.getUserVerificationOption());\n userVerificationVerifier.verifyUserVerificationOption(authenticationEntity.getUserVerificationOption(), authData);\n\n byte[] clientDataHash = DigestUtils.getSha256Digest().digest(base64Service.urlDecode(clientDataJson));\n\n try {\n int counter = authenticatorDataParser.parseCounter(authData.getCounters());\n 
commonVerifiers.verifyCounter(registration.getCounter(), counter);\n registration.setCounter(counter);\n\n JsonNode uncompressedECPointNode = dataMapperService.cborReadTree(base64Service.urlDecode(registration.getUncompressedECPoint()));\n PublicKey publicKey = coseService.createUncompressedPointFromCOSEPublicKey(uncompressedECPointNode);\n\n log.debug(\"Uncompressed ECpoint node: {}\", uncompressedECPointNode);\n log.debug(\"EC Public key hex: {}\", Hex.encodeHexString(publicKey.getEncoded()));\n log.debug(\"Registration algorithm: {}, default use: -7\", registration.getSignatureAlgorithm());\n authenticatorDataVerifier.verifyAssertionSignature(authData, clientDataHash, signature, publicKey, -7);\n\n } catch (Fido2CompromisedDevice ex) {\n log.error(\"Error compromised device: {}\", ex.getMessage());\n throw ex;\n } catch (Exception ex) {\n log.error(\"Error to check none assertion: {}\", ex.getMessage());\n throw new Fido2RuntimeException(\"Failed to check none assertion: {}\", ex.getMessage(), ex);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "process_ifCborReadTreeThrowException_fido2RuntimeException", + "reference": " @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n String base64AuthenticatorData = \"base64AuthenticatorData_test\";\n String signature = \"signature_test\";\n String clientDataJson = \"clientDataJson_test\";\n Fido2RegistrationData registration = mock(Fido2RegistrationData.class);\n Fido2AuthenticationData authenticationEntity = mock(Fido2AuthenticationData.class);\n\n when(authenticationEntity.getUserVerificationOption()).thenReturn(UserVerification.preferred);\n when(registration.getDomain()).thenReturn(\"domain_test\");\n when(registration.getCounter()).thenReturn(100);\n 
when(registration.getUncompressedECPoint()).thenReturn(\"uncompressedECPoint_test\");\n\n when(authenticatorDataParser.parseAssertionData(any())).thenReturn(mock(AuthData.class));\n when(base64Service.urlDecode(any(String.class))).thenReturn(\"decode_test\".getBytes());\n when(dataMapperService.cborReadTree(any(byte[].class))).thenThrow(new IOException(\"IOException_test\"));\n\n Fido2RuntimeException ex = assertThrows(Fido2RuntimeException.class, () -> noneAssertionFormatProcessor.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity));\n assertNotNull(ex);\n assertEquals(ex.getMessage(), \"IOException_test\");\n\n verify(log).debug(eq(\"Registration: {}\"), any(Fido2RegistrationData.class));\n verify(log).debug(eq(\"User verification option: {}\"), any(UserVerification.class));\n verify(commonVerifiers).verifyRpIdHash(any(AuthData.class), any(String.class));\n verify(authenticatorDataParser).parseCounter(any());\n verify(log).error(eq(\"Error to check none assertion: {}\"), any(String.class));\n\n verify(log, never()).error(eq(\"Error compromised device: {}\"), any(String.class));\n verifyNoInteractions(coseService, authenticatorDataVerifier);\n verifyNoMoreInteractions(log);\n }", + "completions": [ + "```java\n @Test\n void process_ifCborReadTreeThrowException_fido2RuntimeException() throws Fido2CompromisedDevice, IOException {\n // Arrange\n String base64AuthenticatorData = \"base64AuthenticatorData\";\n String signature = \"signature\";\n String clientDataJson = \"clientDataJson\";\n Fido2RegistrationData registration = new Fido2RegistrationData();\n Fido2AuthenticationData authenticationEntity = new Fido2AuthenticationData();\n\n when(authenticatorDataParser.parseAssertionData(base64AuthenticatorData)).thenReturn(new AuthData());\n when(authenticatorDataParser.parseCounter(any())).thenReturn(1);\n when(dataMapperService.cborReadTree(any())).thenThrow(new IOException());\n\n // Act & Assert\n 
Assertions.assertThrows(Fido2RuntimeException.class, () -> {\n fido2AuthenticationService.process(base64AuthenticatorData, signature, clientDataJson, registration, authenticationEntity);\n });\n }\n```\n" + ], + "line": 38, + "token": 365, + "line_diff": 2, + "token_diff": 4 + }, + { + "id": 18, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void setUserData_invalidKey_throws()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "setUserData_invalidKey_throws", + "reference": "\t@Test\n\tpublic void setUserData_invalidKey_throws() {\n\t\t// setup\n\t\tTestParameters parameters = new TestParameters();\n\n\t\t// test\n\t\tfor (String key : new String[] { null, \"\" }) {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, \"test\");\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid key; key must be non-empty, non-null\"),\n\t\t\t\t\tex.getMessage());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@Test(expected = IllegalArgumentException.class)\n\tpublic void setUserData_invalidKey_throws() {\n\t\tTestClass testClass = new TestClass();\n\t\ttestClass.setUserData(\"\", \"value\");\n\t}\n```\n\nIn this test, we are creating an instance of the TestClass and calling 
the setUserData method with an empty key and a valid value. We expect this to throw an IllegalArgumentException, which is what we specify in the @Test annotation.\n" + ], + "line": 23, + "token": 170, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n public ListenableFuture queue(ExecuteEntry executeEntry, Poller poller, Duration timeout) {\n ExecuteOperationMetadata metadata =\n ExecuteOperationMetadata.newBuilder()\n .setActionDigest(executeEntry.getActionDigest())\n .setStdoutStreamName(executeEntry.getStdoutStreamName())\n .setStderrStreamName(executeEntry.getStderrStreamName())\n .build();\n Operation operation =\n Operation.newBuilder()\n .setName(executeEntry.getOperationName())\n .setMetadata(Any.pack(metadata))\n .build();\n Digest actionDigest = executeEntry.getActionDigest();\n ActionKey actionKey = DigestUtil.asActionKey(actionDigest);\n\n Stopwatch stopwatch = Stopwatch.createStarted();\n ListenableFuture cachedResultFuture;\n if (executeEntry.getSkipCacheLookup()) {\n cachedResultFuture = immediateFuture(false);\n } else {\n cachedResultFuture =\n checkCacheFuture(actionKey, operation, executeEntry.getRequestMetadata());\n }\n return transformAsync(\n cachedResultFuture,\n (cachedResult) -> {\n if (cachedResult) {\n poller.pause();\n long checkCacheUSecs = stopwatch.elapsed(MICROSECONDS);\n log.log(\n Level.FINER,\n format(\n \"ServerInstance(%s): checkCache(%s): %sus elapsed\",\n getName(), operation.getName(), checkCacheUSecs));\n return IMMEDIATE_VOID_FUTURE;\n }\n return transformAndQueue(executeEntry, poller, operation, stopwatch, timeout);\n },\n operationTransformService);\n }\n```\n###test function signature:\n```java\n @Test\n public void 
queueDirectoryMissingErrorsOperation() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "queueDirectoryMissingErrorsOperation", + "reference": " @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ByteString foo = ByteString.copyFromUtf8(\"foo\");\n Digest subdirDigest = DIGEST_UTIL.compute(foo);\n Directory inputRoot =\n Directory.newBuilder()\n .addDirectories(\n DirectoryNode.newBuilder().setName(\"missing-subdir\").setDigest(subdirDigest))\n .build();\n ByteString inputRootContent = inputRoot.toByteString();\n Digest inputRootDigest = DIGEST_UTIL.compute(inputRootContent);\n provideBlob(inputRootDigest, inputRootContent);\n Action action = createAction(true, true, inputRootDigest, SIMPLE_COMMAND);\n Digest actionDigest = DIGEST_UTIL.compute(action);\n\n ExecuteEntry executeEntry =\n ExecuteEntry.newBuilder()\n .setOperationName(\"missing-directory-operation\")\n .setActionDigest(actionDigest)\n .setSkipCacheLookup(true)\n .build();\n\n when(mockBackplane.propertiesEligibleForQueue(anyList())).thenReturn(true);\n\n when(mockBackplane.canQueue()).thenReturn(true);\n\n Poller poller = mock(Poller.class);\n\n boolean failedPreconditionExceptionCaught = false;\n try {\n instance\n .queue(executeEntry, poller, DEFAULT_TIMEOUT)\n .get(QUEUE_TEST_TIMEOUT_SECONDS, SECONDS);\n } catch (ExecutionException e) {\n com.google.rpc.Status status = StatusProto.fromThrowable(e);\n if (status.getCode() == Code.FAILED_PRECONDITION.getNumber()) {\n failedPreconditionExceptionCaught = true;\n } else {\n e.getCause().printStackTrace();\n }\n }\n assertThat(failedPreconditionExceptionCaught).isTrue();\n\n PreconditionFailure preconditionFailure =\n PreconditionFailure.newBuilder()\n .addViolations(\n Violation.newBuilder()\n .setType(VIOLATION_TYPE_MISSING)\n .setSubject(\"blobs/\" + DigestUtil.toString(subdirDigest))\n .setDescription(\"The directory `/missing-subdir` was not found in the CAS.\"))\n .build();\n 
ExecuteResponse executeResponse =\n ExecuteResponse.newBuilder()\n .setStatus(\n com.google.rpc.Status.newBuilder()\n .setCode(Code.FAILED_PRECONDITION.getNumber())\n .setMessage(invalidActionVerboseMessage(actionDigest, preconditionFailure))\n .addDetails(Any.pack(preconditionFailure)))\n .build();\n assertResponse(executeResponse);\n verify(poller, atLeastOnce()).pause();\n }", + "completions": [ + "```java\n @Test\n public void queueDirectoryMissingErrorsOperation() throws Exception {\n ExecuteEntry executeEntry = mock(ExecuteEntry.class);\n Poller poller = mock(Poller.class);\n Duration timeout = Duration.ofSeconds(10);\n\n when(executeEntry.getActionDigest()).thenReturn(DigestUtil.hashString(DigestUtil.DEFAULT_INSTANCE_NAME + \"action\"));\n when(executeEntry.getStdoutStreamName()).thenReturn(\"stdout\");\n when(executeEntry.getStderrStreamName()).thenReturn(\"stderr\");\n when(executeEntry.getOperationName()).thenReturn(\"operation\");\n when(executeEntry.getRequestMetadata()).thenReturn(RequestMetadata.newBuilder().build());\n when(executeEntry.getSkipCacheLookup()).thenReturn(false);\n\n ActionKey actionKey = DigestUtil.asActionKey(executeEntry.getActionDigest());\n when(checkCacheFuture(actionKey, any(), any())).thenReturn(immediateFuture(true));\n\n ListenableFuture future = serverInstance.queue(executeEntry, poller, timeout);\n assertThat(future.get()).isEqualTo(null);\n }\n```\n" + ], + "line": 51, + "token": 300, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 20, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n 
kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time {}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": 
"testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime = Instant.now();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more 
messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n\n // Start recovery with an offset range that is partially unavailable\n long startOffset = 50;\n long endOffset = 150;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry)).isEqualTo(50);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(51);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(1);\n assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isNotEmpty();\n assertThat(blobFs.exists(URI.create(snapshots.get(0).snapshotPath))).isTrue();\n assertThat(blobFs.listFiles(URI.create(snapshots.get(0).snapshotPath), false).length)\n .isGreaterThan(1);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n 
assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithPartiallyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 100L,\n 200L,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 7, + "token_diff": 6 + }, + { + "id": 21, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public List selectHeaderAccept(String[] accepts) {\n if (accepts.length == 0) {\n return null;\n }\n for (String accept : accepts) {\n MediaType mediaType = MediaType.parseMediaType(accept);\n if (isJsonMime(mediaType) && !isProblemJsonMime(accept)) {\n return Collections.singletonList(mediaType);\n }\n }\n return MediaType.parseMediaTypes(StringUtils.arrayToCommaDelimitedString(accepts));\n }\n```\n###test function signature:\n```java\n @Test\n public void testSelectHeaderAccept()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSelectHeaderAccept", + "reference": " @Test\n public void testSelectHeaderAccept() {\n String[] accepts = {\"application/json\", \"application/xml\"};\n assertEquals(\"application/json\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"APPLICATION/XML\", \"APPLICATION/JSON\"};\n assertEquals(\"APPLICATION/JSON\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"application/xml\", \"application/json; charset=UTF8\"};\n assertEquals(\"application/json; charset=UTF8\", 
apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{\"text/plain\", \"application/xml\"};\n assertEquals(\"text/plain,application/xml\", apiClient.selectHeaderAccept(accepts));\n\n accepts = new String[]{};\n assertNull(apiClient.selectHeaderAccept(accepts));\n }", + "completions": [ + "```java\n @Test\n public void testSelectHeaderAccept() {\n // Given\n String[] accepts = {\"application/json\", \"application/problem+json\"};\n // When\n List result = selectHeaderAccept(accepts);\n // Then\n assertEquals(1, result.size());\n assertEquals(\"application/json\", result.get(0).toString());\n }\n```\n" + ], + "line": 23, + "token": 154, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollEventSentOnConsumerPoll", + "reference": " @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n SubscriptionState subscriptions = new SubscriptionState(new LogContext(), OffsetResetStrategy.NONE);\n consumer = newConsumer(\n mock(FetchBuffer.class),\n new ConsumerInterceptors<>(Collections.emptyList()),\n mock(ConsumerRebalanceListenerInvoker.class),\n subscriptions,\n singletonList(new RoundRobinAssignor()),\n \"group-id\",\n \"client-id\");\n final TopicPartition tp = new TopicPartition(\"topic\", 0);\n final List> records = singletonList(\n new ConsumerRecord<>(\"topic\", 0, 2, \"key1\", \"value1\"));\n doAnswer(invocation -> Fetch.forPartition(tp, records, true))\n .when(fetchCollector)\n .collectFetch(Mockito.any(FetchBuffer.class));\n\n consumer.subscribe(singletonList(\"topic1\"));\n consumer.poll(Duration.ofMillis(100));\n verify(applicationEventHandler).add(any(PollEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testEnsurePollEventSentOnConsumerPoll() {\n // Arrange\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n 
consumer.subscribe(Collections.singletonList(\"test-topic\"));\n consumer.poll(Duration.ofMillis(100));\n\n // Act\n ConsumerRecords records = consumer.poll(Duration.ofMillis(100));\n\n // Assert\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 23, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetMultiplePartitionRecoveriesBehind", + "reference": " @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new 
RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n final RecoveryTaskMetadata recoveryTask2 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"2\",\n \"2\",\n recoveryStartOffset * 3 + 1,\n recoveryStartOffset * 4,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask2);\n final RecoveryTaskMetadata recoveryTask21 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"21\", \"2\", recoveryStartOffset * 4 + 1, 50000, createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask21);\n await()\n .until(\n () ->\n recoveryTaskStore\n .listSync()\n .containsAll(\n List.of(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(5);\n assertThat(recoveryTasks)\n .contains(recoveryTask1, recoveryTask11, recoveryTask2, recoveryTask21);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetMultiplePartitionRecoveriesBehind() {\n // Given\n String partitionId = \"partition1\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new 
AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 24, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture commitStreamObject(apache.rocketmq.controller.v1.S3StreamObject streamObject,\n List compactedObjects) {\n LOGGER.info(\"commitStreamObject with streamObject: {}, compactedObjects: {}\", TextFormat.shortDebugString(streamObject),\n compactedObjects);\n\n CompletableFuture future = new CompletableFuture<>();\n try (SqlSession session = sessionFactory.openSession()) {\n if (streamObject.getObjectId() == S3Constants.NOOP_OBJECT_ID) {\n LOGGER.error(\"S3StreamObject[object-id={}] is null or objectId is unavailable\", streamObject.getObjectId());\n String msg = String.format(\"S3StreamObject[object-id=%d] is null or objectId is unavailable\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.NOT_FOUND_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n\n long committedTs = System.currentTimeMillis();\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n\n // commit object\n if (!commitObject(streamObject.getObjectId(), streamObject.getStreamId(), streamObject.getObjectSize(), session)) {\n String msg = 
String.format(\"S3StreamObject[object-id=%d] is not ready for commit\",\n streamObject.getObjectId());\n ControllerException e = new ControllerException(Code.ILLEGAL_STATE_VALUE, msg);\n future.completeExceptionally(e);\n return future;\n }\n long dataTs = committedTs;\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n dataTs = compactedObjects.stream()\n .map(id -> {\n // mark destroy compacted object\n S3Object object = s3ObjectMapper.getById(id);\n object.setState(S3ObjectState.BOS_WILL_DELETE);\n object.setMarkedForDeletionTimestamp(new Date());\n s3ObjectMapper.markToDelete(object.getId(), new Date());\n\n // update dataTs to the min compacted object's dataTs\n com.automq.rocketmq.metadata.dao.S3StreamObject s3StreamObject =\n s3StreamObjectMapper.getByObjectId(id);\n return s3StreamObject.getBaseDataTimestamp().getTime();\n })\n .min(Long::compareTo).get();\n }\n\n List toCache = new ArrayList<>();\n\n // create a new S3StreamObject to replace committed ones\n if (streamObject.getObjectId() != S3Constants.NOOP_OBJECT_ID) {\n com.automq.rocketmq.metadata.dao.S3StreamObject newS3StreamObj =\n new com.automq.rocketmq.metadata.dao.S3StreamObject();\n newS3StreamObj.setStreamId(streamObject.getStreamId());\n newS3StreamObj.setObjectId(streamObject.getObjectId());\n newS3StreamObj.setObjectSize(streamObject.getObjectSize());\n newS3StreamObj.setStartOffset(streamObject.getStartOffset());\n newS3StreamObj.setEndOffset(streamObject.getEndOffset());\n newS3StreamObj.setBaseDataTimestamp(new Date(dataTs));\n newS3StreamObj.setCommittedTimestamp(new Date(committedTs));\n s3StreamObjectMapper.create(newS3StreamObj);\n toCache.add(newS3StreamObj);\n }\n\n // delete the compactedObjects of S3Stream\n if (!Objects.isNull(compactedObjects) && !compactedObjects.isEmpty()) {\n compactedObjects.forEach(id -> s3StreamObjectMapper.delete(null, null, id));\n }\n session.commit();\n\n // Update Cache\n s3StreamObjectCache.cache(streamObject.getStreamId(), 
toCache);\n s3StreamObjectCache.onCompact(streamObject.getStreamId(), compactedObjects);\n\n LOGGER.info(\"S3StreamObject[object-id={}] commit success, compacted objects: {}\",\n streamObject.getObjectId(), compactedObjects);\n future.complete(null);\n } catch (Exception e) {\n LOGGER.error(\"CommitStream failed\", e);\n ControllerException ex = new ControllerException(Code.INTERNAL_VALUE, \"CommitStream failed\" + e.getMessage());\n future.completeExceptionally(ex);\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCommitStreamObject() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCommitStreamObject", + "reference": " @Test\n public void testCommitStreamObject() throws IOException {\n long objectId, streamId = 1;\n\n try (S3MetadataService metadataService = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n objectId = metadataService.prepareS3Objects(3, 5).get();\n } catch (Exception e) {\n throw new RuntimeException(e);\n }\n\n S3StreamObject news3StreamObject = S3StreamObject.newBuilder()\n .setObjectId(objectId + 2)\n .setStreamId(streamId)\n .setObjectSize(111L)\n .build();\n\n try (SqlSession session = this.getSessionFactory().openSession()) {\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n buildS3StreamObjs(objectId, 2, 3, 100L).forEach(s3StreamObjectMapper::create);\n session.commit();\n }\n\n try (S3MetadataService service = new DefaultS3MetadataService(config, getSessionFactory(), executorService)) {\n List compactedObjects = new ArrayList<>();\n compactedObjects.add(objectId);\n compactedObjects.add(objectId + 1);\n service.commitStreamObject(news3StreamObject, compactedObjects);\n }\n\n try (SqlSession session = getSessionFactory().openSession()) {\n S3ObjectMapper s3ObjectMapper = session.getMapper(S3ObjectMapper.class);\n S3Object s3Object = s3ObjectMapper.getById(objectId);\n 
Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object.getState());\n\n S3Object s3Object1 = s3ObjectMapper.getById(objectId);\n Assertions.assertEquals(S3ObjectState.BOS_WILL_DELETE, s3Object1.getState());\n\n S3StreamObjectMapper s3StreamObjectMapper = session.getMapper(S3StreamObjectMapper.class);\n for (long index = objectId; index < objectId + 2; index++) {\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(index);\n Assertions.assertNull(object);\n }\n\n com.automq.rocketmq.metadata.dao.S3StreamObject object = s3StreamObjectMapper.getByObjectId(objectId + 2);\n Assertions.assertEquals(111L, object.getObjectSize());\n Assertions.assertEquals(streamId, object.getStreamId());\n Assertions.assertTrue(object.getBaseDataTimestamp().getTime() > 0);\n Assertions.assertTrue(object.getCommittedTimestamp().getTime() > 0);\n }\n }", + "completions": [ + "```java\n @Test\n public void testCommitStreamObject() throws IOException {\n // Given\n apache.rocketmq.controller.v1.S3StreamObject streamObject = new apache.rocketmq.controller.v1.S3StreamObject();\n streamObject.setObjectId(1L);\n streamObject.setStreamId(2L);\n streamObject.setObjectSize(3L);\n streamObject.setStartOffset(4L);\n streamObject.setEndOffset(5L);\n\n List compactedObjects = Arrays.asList(6L, 7L, 8L);\n\n // When\n CompletableFuture future = controller.commitStreamObject(streamObject, compactedObjects);\n\n // Then\n assertDoesNotThrow(() -> future.get());\n }\n```\n" + ], + "line": 85, + "token": 664, + "line_diff": 6, + "token_diff": 7 + }, + { + "id": 25, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyTooManyResults", + "reference": " @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, 
mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(totalHits.value).thenReturn(2L);\n Mockito.when(hits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(response.getHits()).thenReturn(hits);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processService.getProcessByKey(processDefinitionKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Expected exception message to match\",\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n e.getMessage());\n }\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 26, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String doGet(String url) {\n try (CloseableHttpClient client = HttpClients.createDefault()) {\n HttpGet request = new HttpGet(url);\n HttpResponse response = client.execute(request);\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n return EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n }\n } catch (IOException e) {\n log.error(\"DoGet exception.\", e);\n\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n void testDoGet()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDoGet", + "reference": " @Test\n void testDoGet() {\n String result = HttpUtil.doGet(\"http://baidu.com\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testDoGet() {\n String url = \"http://example.com\";\n String expectedResponse = \"Hello, world!\";\n\n // Mock the HttpClient and HttpResponse\n CloseableHttpClient mockClient = Mockito.mock(CloseableHttpClient.class);\n HttpResponse mockResponse = Mockito.mock(HttpResponse.class);\n StatusLine mockStatusLine = Mockito.mock(StatusLine.class);\n HttpEntity mockEntity = Mockito.mock(HttpEntity.class);\n\n // Set up the mock objects to return the expected values\n Mockito.when(mockClient.execute(Mockito.any(HttpGet.class))).thenReturn(mockResponse);\n Mockito.when(mockResponse.getStatusLine()).thenReturn(mockStatusLine);\n Mockito.when(mockStatusLine.getStatusCode()).thenReturn(HttpStatus.SC_OK);\n Mockito.when(mockResponse.getEntity()).thenReturn(mockEntity);\n Mockito.when(EntityUtils.toString(mockEntity, StandardCharsets.UTF_8)).thenReturn(expectedResponse);\n\n // Call the method under test\n String actualResponse = YourClassName.doGet(url);\n\n // Verify the result\n 
Assertions.assertEquals(expectedResponse, actualResponse);\n }\n```\nPlease replace `YourClassName` with the actual name of the class where the `doGet` method is located.\n" + ], + "line": 23, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n dpopService.validateDpopThumprint(null, \"any\");\n }", + "completions": [ + "```java\n @Test\n public void 
validateDpopThumprint_whenExistingDpopThumprintIsMissedAndConfIsFalse_shouldBeValid() {\n // Given\n String existingThumprint = \"\";\n String requestThumprint = \"requestThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n try {\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), e.getResponse().getStatus());\n assertEquals(\"Thumprint does not match\", e.getMessage());\n }\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 28, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void sendData(String database, String table, String content, boolean partitioned) throws Exception {\n final String loadUrl = String.format(\"http://%s:%s/api/%s/%s/_stream_load?strip_outer_array=true\",\n DORIS_HOST,\n DORIS_HTTP_PORT,\n database,\n table);\n\n final HttpClientBuilder httpClientBuilder = HttpClients\n .custom()\n .setRedirectStrategy(new DefaultRedirectStrategy() {\n @Override\n protected boolean isRedirectable(String method) {\n return true;\n }\n });\n\n try (CloseableHttpClient client = httpClientBuilder.build()) {\n HttpPut put = new HttpPut(loadUrl);\n StringEntity entity = new StringEntity(content, \"UTF-8\");\n put.setHeader(HttpHeaders.EXPECT, \"100-continue\");\n put.setHeader(HttpHeaders.AUTHORIZATION, HttpUtil.basicAuthHeader(DORIS_USER, DORIS_PASSWORD));\n put.setHeader(\"max_filter_ratio\", \"0.1\");\n if (partitioned) {\n SimpleDateFormat simpleDateFormat = new SimpleDateFormat(\"yyyyMMdd\");\n put.setHeader(\"partitions\", \"p\" + simpleDateFormat.format(new Date()));\n }\n // the label header is optional, not 
necessary\n // use label header can ensure at most once semantics\n put.setEntity(entity);\n try (CloseableHttpResponse response = client.execute(put)) {\n String contentStr = new String(ByteStreams.toByteArray(response.getEntity().getContent()));\n JsonObject jsonObject = jsonParser.parse(contentStr).getAsJsonObject();\n log.info(\"result:{}\", contentStr);\n int statusCode = response.getStatusLine().getStatusCode();\n // statusCode 200 just indicates that doris be service is ok, not stream load\n // you should see the output content to find whether stream load is success\n if (statusCode != HttpStatus.SC_OK || (!jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"SUCCESS\") &&\n !jsonObject.get(\"Status\").getAsString().equalsIgnoreCase(\"Publish Timeout\"))) {\n throw new IOException(\n String.format(\"Stream load failed, statusCode=%s load result=%s content=%s\", statusCode, jsonObject.toString(), content));\n }\n } catch (Exception e) {\n log.error(\"error\", e);\n } finally {\n client.close();\n }\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sendDataMapTest() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendDataMapTest", + "reference": " @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test\";\n String tableName = \"hera_log_doris_table_2_11\";\n for (int i = 0; i < 1; i++) {\n String jsonString = \"{\\\"linenumber\\\":1256,\\\"tailId\\\":90219,\\\"mqtag\\\":\\\"tags_4_13_90219\\\",\\\"filename\\\":\\\"/home/work/log/log-agent/server.log\\\",\\\"tail\\\":\\\"demo-client-agent\\\",\\\"mqtopic\\\":\\\"90219_hera-demo-client\\\",\\\"message\\\":\\\"2024-01-10 19:34:40,360|INFO ||NettyClientPublicExecutor_1|c.x.mone.log.agent.rpc.task.PingTask|83|ping res: log-agent-server:2022-12-05:0.0.2->2024-01-10 19:34:40 358->10.53.129.250:9899\\\",\\\"logstore\\\":\\\"测试doris日志\\\",\\\"logip\\\":\\\"10.53.129.176\\\",\\\"timestamp\\\":1704886481241}\";\n\n Map 
map = gson.fromJson(jsonString, Map.class);\n dorisStreamLoad.sendData(database, tableName, columnListNew, map);\n }\n }", + "completions": [ + "```java\n @Test\n public void sendDataMapTest() throws Exception {\n String database = \"test_db\";\n String table = \"test_table\";\n String content = \"{\\\"columns\\\":[\\\"id\\\",\\\"name\\\",\\\"age\\\"],\\\"rows\\\":[[1,\\\"John\\\",25],[2,\\\"Jane\\\",30]]}\";\n boolean partitioned = true;\n\n sendData(database, table, content, partitioned);\n }\n```\n" + ], + "line": 57, + "token": 486, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 29, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyNoResults() throws 
IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyNoResults", + "reference": " @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyNoResults() throws IOException {\n // Given\n final Long processDefinitionKey = 123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final NotFoundException exception =\n Assertions.assertThrows(\n NotFoundException.class, () -> processService.getProcessByKey(processDefinitionKey));\n\n // Then\n Assertions.assertEquals(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 
36, + "token": 279, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 30, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String getDiagramByKey(Long processDefinitionKey) {\n final IdsQueryBuilder q = idsQuery().addIds(processDefinitionKey.toString());\n\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n\n if (response.getHits().getTotalHits().value == 1) {\n final Map result = response.getHits().getHits()[0].getSourceAsMap();\n return (String) result.get(BPMN_XML);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with id '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with id '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\n \"Exception occurred, while obtaining the process diagram: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetDiagramByKeyTooManyResults", + "reference": " @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would 
just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(2L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getDiagramByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetDiagramByKeyTooManyResults() throws IOException {\n // Given\n final Long processDefinitionKey = 1L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(new SearchSourceBuilder().query(q).fetchSource(BPMN_XML, null));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits hits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(hits);\n Mockito.when(hits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHits.Relation.EQUAL_TO));\n\n // When\n try {\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n processDiagramService.getDiagramByKey(processDefinitionKey);\n } catch (NotFoundException e) {\n // Then\n Assert.assertEquals(\n \"Could not find unique process with id '\" + processDefinitionKey + \"'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 37, + "token": 309, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 31, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n void shutdown() {\n LOG.info(\"Running shutdown hook.\");\n try {\n serviceManager.stopAsync().awaitStopped(30, TimeUnit.SECONDS);\n } catch (Exception e) {\n // stopping timed out\n LOG.error(\"ServiceManager shutdown timed out\", e);\n }\n try {\n curatorFramework.unwrap().close();\n } catch (Exception e) {\n LOG.error(\"Error while closing curatorFramework \", e);\n }\n LOG.info(\"Shutting down LogManager\");\n LogManager.shutdown();\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexerShutdownTwice() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexerShutdownTwice", + "reference": " @Test\n public void testIndexerShutdownTwice() throws Exception {\n startKafkaServer();\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // Create a live partition for this partiton\n final String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTimeMs = 1;\n final long endTimeMs = 100;\n final long maxOffset = 30;\n SnapshotMetadata livePartition0 =\n new SnapshotMetadata(\n name + \"live0\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"0\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition0);\n\n SnapshotMetadata livePartition1 =\n new SnapshotMetadata(\n name + \"live1\",\n LIVE_SNAPSHOT_PATH,\n startTimeMs,\n endTimeMs,\n maxOffset,\n \"1\",\n LOGS_LUCENE9);\n snapshotMetadataStore.createSync(livePartition1);\n\n final SnapshotMetadata partition0 =\n new SnapshotMetadata(name, path, startTimeMs, endTimeMs, maxOffset, \"0\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition0);\n\n 
assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .containsOnly(livePartition1, livePartition0, partition0);\n\n // Empty consumer offset since there is no prior consumer.\n astraIndexer =\n new AstraIndexer(\n chunkManagerUtil.chunkManager,\n curatorFramework,\n makeIndexerConfig(50),\n getKafkaConfig(),\n metricsRegistry);\n astraIndexer.startAsync();\n astraIndexer.awaitRunning(DEFAULT_START_STOP_DURATION);\n await().until(() -> kafkaServer.getConnectedConsumerGroups() == 1);\n\n // Produce more messages since the recovery task is created for head.\n produceMessagesToKafka(kafkaServer.getBroker(), startTime);\n\n consumeMessagesAndSearchMessagesTest(100, 1);\n\n // Live snapshot is deleted, recovery task is created.\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(livePartition1, partition0);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isEqualTo(4);\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).size()).isEqualTo(1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(searchMetadataStore).size()).isEqualTo(1);\n RecoveryTaskMetadata recoveryTask1 =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore).get(0);\n assertThat(recoveryTask1.startOffset).isEqualTo(31);\n assertThat(recoveryTask1.endOffset).isEqualTo(99);\n assertThat(recoveryTask1.partitionId).isEqualTo(\"0\");\n\n // Shutting down is idempotent. 
So, doing it twice shouldn't throw an error.\n astraIndexer.shutDown();\n astraIndexer.shutDown();\n astraIndexer = null;\n }", + "completions": [ + "```java\n @Test\n public void testIndexerShutdownTwice() throws Exception {\n // Given\n Indexer indexer = new Indexer();\n indexer.serviceManager = Mockito.mock(ServiceManager.class);\n indexer.curatorFramework = Mockito.mock(CuratorFramework.class);\n Mockito.when(indexer.serviceManager.stopAsync()).thenReturn(Mockito.mock(ServiceManager.class));\n Mockito.when(indexer.serviceManager.awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class))).thenThrow(new TimeoutException());\n Mockito.when(indexer.curatorFramework.unwrap()).thenReturn(Mockito.mock(Closeable.class));\n Mockito.doThrow(new IOException()).when(indexer.curatorFramework.unwrap()).close();\n\n // When\n indexer.shutdown();\n\n // Then\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).stopAsync();\n Mockito.verify(indexer.serviceManager, Mockito.times(1)).awaitStopped(Mockito.anyLong(), Mockito.any(TimeUnit.class));\n Mockito.verify(indexer.curatorFramework.unwrap(), Mockito.times(1)).close();\n }\n```\n" + ], + "line": 27, + "token": 177, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_notInBlacklistAddressIsNull", + "reference": " @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull(){\n when(blacklistService.inBlacklist(anyString())).thenReturn(true);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n SendStatus result = mailUtils.sendHtmlMail(addressList,\"test\",\"test\",\"test\", SystemEnum.AGENT, MsgTypeEnum.ALARM);\n Assertions.assertEquals(\"false\",result.getStatus());\n Assertions.assertTrue(result.getErrorMessage().contains(\"blacklist\"));\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_notInBlacklistAddressIsNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = null;\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://test.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 
33, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessEntity getProcessByKey(Long processDefinitionKey) {\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n try {\n final SearchResponse response = tenantAwareClient.search(searchRequest);\n if (response.getHits().getTotalHits().value == 1) {\n return fromSearchHit(response.getHits().getHits()[0].getSourceAsString());\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey));\n } else {\n throw new NotFoundException(\n String.format(\"Could not find process with key '%s'.\", processDefinitionKey));\n }\n } catch (IOException e) {\n final String message =\n String.format(\"Exception occurred, while obtaining the process: %s\", e.getMessage());\n LOGGER.error(message, e);\n throw new OperateRuntimeException(message, e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessByKeyWithException", + "reference": " @Test\n public void testGetProcessByKeyWithException() throws IOException {\n when(processIndex.getAlias()).thenReturn(\"processIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(OperateRuntimeException.class, () -> underTest.getProcessByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessByKeyWithException() throws IOException {\n // Given\n final Long processDefinitionKey = 
123L;\n final SearchRequest searchRequest =\n new SearchRequest(processIndex.getAlias())\n .source(\n new SearchSourceBuilder()\n .query(QueryBuilders.termQuery(ProcessIndex.KEY, processDefinitionKey)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n final TotalHits totalHits = Mockito.mock(TotalHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(totalHits);\n Mockito.when(totalHits.value).thenReturn(2L);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n final Executable executable = () -> processService.getProcessByKey(processDefinitionKey);\n\n // Then\n final NotFoundException exception =\n Assertions.assertThrows(NotFoundException.class, executable);\n Assertions.assertEquals(\n String.format(\"Could not find unique process with key '%s'.\", processDefinitionKey),\n exception.getMessage());\n }\n```\n" + ], + "line": 36, + "token": 279, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 34, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testConcurrentSearches() throws InterruptedException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testConcurrentSearches", + "reference": " @Test\n public void testConcurrentSearches() throws InterruptedException {\n Instant time = Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n AtomicInteger searchFailures = new AtomicInteger(0);\n AtomicInteger statsFailures = new AtomicInteger(0);\n AtomicInteger searchExceptions = new AtomicInteger(0);\n AtomicInteger successfulRuns = new AtomicInteger(0);\n\n Runnable searchRun =\n () -> {\n for (int i = 0; i < 100; i++) {\n try {\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR 
_id:Message4\",\n 0L,\n MAX_TIME,\n 100,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n if (babies.hits.size() != 2) {\n searchFailures.addAndGet(1);\n } else {\n successfulRuns.addAndGet(1);\n }\n } catch (Exception e) {\n searchExceptions.addAndGet(1);\n }\n }\n };\n\n Thread t1 = new Thread(searchRun);\n Thread t2 = new Thread(searchRun);\n t1.start();\n t2.start();\n t1.join();\n t2.join();\n assertThat(searchExceptions.get()).isEqualTo(0);\n assertThat(statsFailures.get()).isEqualTo(0);\n assertThat(searchFailures.get()).isEqualTo(0);\n assertThat(successfulRuns.get()).isEqualTo(200);\n }", + "completions": [ + "```java\n @Test\n public void testConcurrentSearches() throws InterruptedException {\n // Create a fixed thread pool with 10 threads\n ExecutorService executorService = Executors.newFixedThreadPool(10);\n\n // Create a list to hold the futures\n List>> futures = new ArrayList<>();\n\n // Submit 100 search tasks to the executor service\n for (int i = 0; i < 100; i++) {\n // Create a search task\n Callable> searchTask = () -> {\n // Generate a random dataset\n String dataset = UUID.randomUUID().toString();\n\n // Generate a random query string\n String queryStr = UUID.randomUUID().toString();\n\n // Generate a random start time\n long startTimeMsEpoch = System.currentTimeMillis() - ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random end time\n long endTimeMsEpoch = startTimeMsEpoch + ThreadLocalRandom.current().nextLong(10000);\n\n // Generate a random number of hits requested\n int howMany = ThreadLocalRandom.current().nextInt(100);\n\n // Generate a random aggregation builder\n AggBuilder aggBuilder = new AggBuilder();\n\n // Perform the search\n return logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n };\n\n // Submit the search task to the executor service and add the future to the list\n 
futures.add(executorService.submit(searchTask));\n }\n\n // Wait for all search tasks to complete\n for (Future> future : futures) {\n future.get();\n }\n\n // Shutdown the executor service\n executorService.shutdown();\n }\n```\n" + ], + "line": 85, + "token": 608, + "line_diff": 6, + "token_diff": 7 + }, + { + "id": 35, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String sendGetData(String path, Map headMap) {\n log.info(\"request tcm, path:{},headMap:{} \",path,headMap);\n String result = \"\";\n CloseableHttpResponse response =null;\n try(CloseableHttpClient httpClient = HttpClients.createDefault()) {\n // 创建get方式请求对象\n HttpGet httpGet = new HttpGet(path);\n if (null != headMap) {\n for (Map.Entry entry : headMap.entrySet()) {\n httpGet.addHeader(entry.getKey(), entry.getValue());\n }\n }\n // 通过请求对象获取响应对象\n response = httpClient.execute(httpGet);\n\n // 获取结果实体\n // 判断网络连接状态码是否正常(0--200都数正常)\n if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) {\n result = EntityUtils.toString(response.getEntity(), UTF_8);\n } else {\n log.error(\"get请求传输 异常.url:{}, headMap:{}\", path, headMap);\n }\n } catch (ClientProtocolException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n } catch (IOException e) {\n log.error(\"get请求传输 异常 \", e);\n log.error(\"get请求传输 异常.path:{}, headMap:{}\", path, headMap);\n }\n finally {\n try {\n if (null!=response){\n response.close();\n }\n } catch (IOException e) {\n log.error(\"关闭response 异常 \", e);\n }\n }\n log.debug(result);\n return result;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendGetData()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendGetData", + "reference": " @Test\n void 
testSendGetData() {\n final Map headMap = new HashMap<>();\n final String result = HttpUtils.sendGetData(\"path\", headMap);\n assertThat(result).isEqualTo(\"\");\n }", + "completions": [ + "```java\n @Test\n void testSendGetData() {\n String path = \"http://localhost:8080/api/test\";\n Map headMap = new HashMap<>();\n headMap.put(\"Content-Type\", \"application/json\");\n headMap.put(\"Authorization\", \"Bearer your_token\");\n\n String result = sendGetData(path, headMap);\n\n // Add assertions to verify the result\n // For example, if the result is expected to be a JSON string, you can use a JSON library to parse the result and assert the values\n // Here is an example using the JSON library Jackson\n ObjectMapper objectMapper = new ObjectMapper();\n JsonNode jsonNode = objectMapper.readTree(result);\n assertEquals(\"expectedValue\", jsonNode.get(\"key\").asText());\n }\n```\nPlease replace \"your_token\" with the actual token you want to use for the test. Also, replace \"expectedValue\" and \"key\" with the actual expected value and key from the JSON response.\n" + ], + "line": 51, + "token": 367, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 36, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ByteString retainKeyPrefix(ByteString tenantNS, List topicFilterLevels) {\n ByteString prefix = ByteString.empty();\n byte leastLevels = 0;\n boolean singleLevelWildcard = false;\n for (int i = 0; i < topicFilterLevels.size(); i++) {\n String tfl = topicFilterLevels.get(i);\n if (\"+\".equals(tfl)) {\n leastLevels++;\n singleLevelWildcard = true;\n continue;\n }\n if (\"#\".equals(tfl)) {\n break;\n }\n leastLevels++;\n if (!singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(tfl));\n }\n if (i + 1 < topicFilterLevels.size()) {\n if (!topicFilterLevels.get(i + 1).equals(\"#\") && !singleLevelWildcard) {\n prefix = prefix.concat(copyFromUtf8(NUL));\n }\n }\n }\n return tenantNS.concat(unsafeWrap(new byte[] {leastLevels})).concat(prefix);\n }\n```\n###test function signature:\n```java\n @Test\n public void testRetainKeyPrefix()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRetainKeyPrefix", + "reference": " @Test\n public void testRetainKeyPrefix() {\n String tenantId = \"tenantA\";\n ByteString tenantNS = tenantNS(tenantId);\n assertEquals(retainKeyPrefix(tenantNS, parse(\"#\", false)), tenantNS.concat(levelByte(0)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"+/#\", false)), tenantNS.concat(levelByte(1)));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/#\", false)),\n tenantNS.concat(levelByte(1).concat(copyFromUtf8(\"a\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n 
assertEquals(retainKeyPrefix(tenantNS, parse(\"a/+\", false)),\n tenantNS.concat(levelByte(2).concat(copyFromUtf8(\"a\").concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"a/b\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)).concat(copyFromUtf8(\"b\"))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/#\", false)),\n tenantNS.concat(\n levelByte(2).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+\", false)),\n tenantNS.concat(\n levelByte(3).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/a/+/+\", false)),\n tenantNS.concat(\n levelByte(4).concat(copyFromUtf8(NUL).concat(copyFromUtf8(\"a\")).concat(copyFromUtf8(NUL)))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/\", false)),\n tenantNS.concat(levelByte(4).concat(copyFromUtf8(NUL))));\n assertEquals(retainKeyPrefix(tenantNS, parse(\"/+/b/+/\", false)),\n tenantNS.concat(levelByte(5).concat(copyFromUtf8(NUL))));\n }", + "completions": [ + "```java\n @Test\n public void testRetainKeyPrefix() {\n ByteString tenantNS = ByteString.copyFromUtf8(\"tenant1\");\n List topicFilterLevels = Arrays.asList(\"level1\", \"level2\", \"+\", \"#\");\n ByteString expectedPrefix = ByteString.copyFromUtf8(\"tenant1\\u0001level1\\u0000level2\");\n ByteString actualPrefix = retainKeyPrefix(tenantNS, topicFilterLevels);\n assertEquals(expectedPrefix, actualPrefix);\n }\n```\n" + ], + "line": 37, + "token": 255, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 37, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n @Transactional(rollbackFor = Exception.class)\n public void updateFileConfigMaster(Long id) {\n // 校验存在\n validateFileConfigExists(id);\n // 更新其它为非 master\n fileConfigMapper.updateBatch(new FileConfigDO().setMaster(false));\n // 更新\n fileConfigMapper.updateById(new FileConfigDO().setId(id).setMaster(true));\n // 发送刷新配置的消息\n TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronization() {\n\n @Override\n public void afterCommit() {\n fileConfigProducer.sendFileConfigRefreshMessage();\n }\n\n });\n }\n```\n###test function signature:\n```java\n @Test\n public void testUpdateFileConfigMaster_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testUpdateFileConfigMaster_success", + "reference": " @Test\n public void testUpdateFileConfigMaster_success() {\n // mock 数据\n FileConfigDO dbFileConfig = randomFileConfigDO().setMaster(false);\n fileConfigMapper.insert(dbFileConfig);// @Sql: 先插入出一条存在的数据\n FileConfigDO masterFileConfig = randomFileConfigDO().setMaster(true);\n fileConfigMapper.insert(masterFileConfig);// @Sql: 先插入出一条存在的数据\n\n // 调用\n fileConfigService.updateFileConfigMaster(dbFileConfig.getId());\n // 断言数据\n assertTrue(fileConfigMapper.selectById(dbFileConfig.getId()).getMaster());\n assertFalse(fileConfigMapper.selectById(masterFileConfig.getId()).getMaster());\n // verify 调用\n verify(fileConfigProducer).sendFileConfigRefreshMessage();\n }", + "completions": [ + "```java\n @Test\n public void testUpdateFileConfigMaster_success() {\n // 准备数据\n FileConfigDO fileConfigDO = new FileConfigDO();\n fileConfigDO.setId(1L);\n fileConfigDO.setMaster(false);\n fileConfigMapper.insert(fileConfigDO);\n\n // 执行方法\n fileConfigService.updateFileConfigMaster(1L);\n\n // 验证结果\n FileConfigDO 
updatedFileConfigDO = fileConfigMapper.selectById(1L);\n Assert.assertTrue(updatedFileConfigDO.getMaster());\n\n // 清理数据\n fileConfigMapper.deleteById(1L);\n }\n```\n" + ], + "line": 28, + "token": 164, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\t@Override\n\tpublic void add(List documents) {\n\n\t\tAssert.notNull(documents, \"The document list should not be null.\");\n\t\tif (CollectionUtils.isEmpty(documents)) {\n\t\t\treturn; // nothing to do;\n\t\t}\n\n\t\tfinal var searchDocuments = documents.stream().map(document -> {\n\t\t\tfinal var embeddings = this.embeddingClient.embed(document);\n\t\t\tSearchDocument searchDocument = new SearchDocument();\n\t\t\tsearchDocument.put(ID_FIELD_NAME, document.getId());\n\t\t\tsearchDocument.put(EMBEDDING_FIELD_NAME, embeddings);\n\t\t\tsearchDocument.put(CONTENT_FIELD_NAME, document.getContent());\n\t\t\tsearchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());\n\n\t\t\t// Add the filterable metadata fields as top level fields, allowing filler\n\t\t\t// expressions on them.\n\t\t\tfor (MetadataField mf : this.filterMetadataFields) {\n\t\t\t\tif (document.getMetadata().containsKey(mf.name())) {\n\t\t\t\t\tsearchDocument.put(METADATA_FIELD_PREFIX + mf.name(), document.getMetadata().get(mf.name()));\n\t\t\t\t}\n\t\t\t}\n\n\t\t\treturn searchDocument;\n\t\t}).toList();\n\n\t\tIndexDocumentsResult result = this.searchClient.uploadDocuments(searchDocuments);\n\n\t\tfor (IndexingResult indexingResult : result.getResults()) {\n\t\t\tAssert.isTrue(indexingResult.isSucceeded(),\n\t\t\t\t\tString.format(\"Document with key %s upload is not successfully\", indexingResult.getKey()));\n\t\t}\n\t}\n```\n###test function 
signature:\n```java\n\t@Test\n\tpublic void addAndSearchTest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "addAndSearchTest", + "reference": "\t@Test\n\tpublic void addAndSearchTest() {\n\n\t\tcontextRunner.run(context -> {\n\n\t\t\tVectorStore vectorStore = context.getBean(VectorStore.class);\n\n\t\t\tvectorStore.add(documents);\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\t\t\t}, hasSize(1));\n\n\t\t\tList results = vectorStore.similaritySearch(SearchRequest.query(\"Great Depression\").withTopK(1));\n\n\t\t\tassertThat(results).hasSize(1);\n\t\t\tDocument resultDoc = results.get(0);\n\t\t\tassertThat(resultDoc.getId()).isEqualTo(documents.get(2).getId());\n\t\t\tassertThat(resultDoc.getContent()).contains(\"The Great Depression (1929–1939) was an economic shock\");\n\t\t\tassertThat(resultDoc.getMetadata()).hasSize(2);\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"meta2\");\n\t\t\tassertThat(resultDoc.getMetadata()).containsKey(\"distance\");\n\n\t\t\t// Remove all documents from the store\n\t\t\tvectorStore.delete(documents.stream().map(doc -> doc.getId()).toList());\n\n\t\t\tAwaitility.await().until(() -> {\n\t\t\t\treturn vectorStore.similaritySearch(SearchRequest.query(\"Hello\").withTopK(1));\n\t\t\t}, hasSize(0));\n\t\t});\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void addAndSearchTest() {\n\t\t// Given\n\t\tList documents = new ArrayList<>();\n\t\tDocument document1 = new Document();\n\t\tdocument1.setId(\"1\");\n\t\tdocument1.setContent(\"This is a test document\");\n\t\tdocument1.setMetadata(new HashMap<>());\n\t\tdocument1.getMetadata().put(\"author\", \"John Doe\");\n\t\tdocument1.getMetadata().put(\"date\", \"2022-01-01\");\n\t\tdocuments.add(document1);\n\n\t\tDocument document2 = new Document();\n\t\tdocument2.setId(\"2\");\n\t\tdocument2.setContent(\"This is another test 
document\");\n\t\tdocument2.setMetadata(new HashMap<>());\n\t\tdocument2.getMetadata().put(\"author\", \"Jane Doe\");\n\t\tdocument2.getMetadata().put(\"date\", \"2022-01-02\");\n\t\tdocuments.add(document2);\n\n\t\t// When\n\t\tsearchService.add(documents);\n\n\t\t// Then\n\t\tSearchRequest searchRequest = new SearchRequest();\n\t\tsearchRequest.setQuery(\"test\");\n\t\tsearchRequest.setFilter(new HashMap<>());\n\t\tsearchRequest.getFilter().put(\"author\", \"John Doe\");\n\t\tSearchResponse searchResponse = searchService.search(searchRequest);\n\n\t\tAssert.assertEquals(1, searchResponse.getDocuments().size());\n\t\tAssert.assertEquals(\"1\", searchResponse.getDocuments().get(0).getId());\n\t}\n```\n" + ], + "line": 39, + "token": 312, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 39, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void collect(CollectRep.MetricsData.Builder builder, long monitorId, String app, Metrics metrics) {\n long startTime = System.currentTimeMillis();\n\n // 校验参数\n try {\n validateParams(metrics);\n } catch (Exception e) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(e.getMessage());\n return;\n }\n\n HttpContext httpContext = createHttpContext(metrics.getNginx());\n HttpUriRequest request = createHttpRequest(metrics.getNginx());\n try (CloseableHttpResponse response = CommonHttpClient.getHttpClient().execute(request, httpContext)){\n // 发起http请求,获取响应数据\n int statusCode = response.getStatusLine().getStatusCode();\n if (statusCode != SUCCESS_CODE) {\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(\"StatusCode \" + statusCode);\n return;\n }\n String resp = EntityUtils.toString(response.getEntity(), StandardCharsets.UTF_8);\n\n Long responseTime = System.currentTimeMillis() - startTime;\n // 
根据metrics name选择调用不同解析方法\n if (NGINX_STATUS_NAME.equals(metrics.getName()) || AVAILABLE.equals(metrics.getName())) {\n parseNginxStatusResponse(builder, resp, metrics, responseTime);\n } else if (REQ_STATUS_NAME.equals(metrics.getName())) {\n parseReqStatusResponse(builder, resp, metrics, responseTime);\n }\n } catch (Exception e) {\n String errorMsg = CommonUtil.getMessageFromThrowable(e);\n log.info(errorMsg);\n builder.setCode(CollectRep.Code.FAIL);\n builder.setMsg(errorMsg);\n } finally {\n if (request != null) {\n request.abort();\n }\n }\n\n }\n```\n###test function signature:\n```java\n @Test\n public void testNginxCollect()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNginxCollect", + "reference": " @Test\n public void testNginxCollect() {\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n long monitorId = 999;\n String app = \"testNginx\";\n\n Metrics metrics = new Metrics();\n metrics.setNginx(nginxProtocol);\n nginxCollect.collect(builder, monitorId, app, metrics);\n }", + "completions": [ + "```java\n @Test\n public void testNginxCollect() {\n // 创建一个Metrics对象\n Metrics metrics = new Metrics();\n metrics.setName(\"nginx_status\");\n metrics.setNginx(new Nginx(\"http://localhost:8080/nginx_status\"));\n\n // 创建一个CollectRep.MetricsData.Builder对象\n CollectRep.MetricsData.Builder builder = CollectRep.MetricsData.newBuilder();\n\n // 创建一个NginxCollect对象\n NginxCollect nginxCollect = new NginxCollect();\n\n // 调用collect方法\n nginxCollect.collect(builder, 1, \"test\", metrics);\n\n // 验证结果\n assertEquals(CollectRep.Code.SUCCESS, builder.getCode());\n assertNotNull(builder.getData());\n assertTrue(builder.getData().getMetrics().containsKey(\"active_connections\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"accepts\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"handled\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"requests\"));\n 
assertTrue(builder.getData().getMetrics().containsKey(\"reading\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"writing\"));\n assertTrue(builder.getData().getMetrics().containsKey(\"waiting\"));\n }\n```\n" + ], + "line": 51, + "token": 335, + "line_diff": 5, + "token_diff": 3 + }, + { + "id": 40, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Transactional(rollbackOn = Exception.class)\n public void submit(@NonNull PartitionPlanConfig partitionPlanConfig)\n throws SchedulerException, ClassNotFoundException {\n Long databaseId = partitionPlanConfig.getDatabaseId();\n Validate.notNull(databaseId, \"DatabaseId can not be null\");\n // disable all related partition plan task\n Database database = this.databaseService.detail(databaseId);\n disablePartitionPlan(database.getId());\n PartitionPlanEntity partitionPlanEntity = modelToEntity(partitionPlanConfig);\n partitionPlanEntity = this.partitionPlanRepository.save(partitionPlanEntity);\n if (!partitionPlanConfig.isEnabled()\n || CollectionUtils.isEmpty(partitionPlanConfig.getPartitionTableConfigs())) {\n log.info(\"Partition plan is disabled or table config is empty, do nothing and return\");\n return;\n }\n Validate.isTrue(partitionPlanConfig.getCreationTrigger() != null, \"Creation trigger can not be null\");\n if (partitionPlanConfig.getDroppingTrigger() == null) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(partitionPlanConfig.getPartitionTableConfigs(),\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n return;\n }\n Map> 
strategy2TblCfgs =\n partitionPlanConfig.getPartitionTableConfigs().stream().flatMap(tableConfig -> {\n Map> strategy2Cfgs =\n tableConfig.getPartitionKeyConfigs().stream()\n .collect(Collectors.groupingBy(PartitionPlanKeyConfig::getStrategy));\n return strategy2Cfgs.values().stream().map(cfgs -> {\n PartitionPlanTableConfig cfg = new PartitionPlanTableConfig();\n cfg.setPartitionKeyConfigs(cfgs);\n cfg.setTableName(tableConfig.getTableName());\n cfg.setEnabled(tableConfig.isEnabled());\n cfg.setPartitionNameInvoker(tableConfig.getPartitionNameInvoker());\n cfg.setPartitionNameInvokerParameters(tableConfig.getPartitionNameInvokerParameters());\n return cfg;\n });\n }).collect(Collectors.groupingBy(cfg -> cfg.getPartitionKeyConfigs().get(0).getStrategy()));\n List createConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.CREATE);\n List dropConfigs = strategy2TblCfgs.get(PartitionPlanStrategy.DROP);\n if (CollectionUtils.isNotEmpty(createConfigs)) {\n ScheduleEntity createScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getCreationTrigger());\n createPartitionPlanTables(createConfigs,\n partitionPlanEntity.getId(), createScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n if (CollectionUtils.isNotEmpty(dropConfigs)) {\n ScheduleEntity dropScheduleEntity = createAndEnableSchedule(\n database, partitionPlanConfig.getDroppingTrigger());\n createPartitionPlanTables(dropConfigs,\n partitionPlanEntity.getId(), dropScheduleEntity.getId(),\n partitionPlanConfig.getFlowInstanceId(), partitionPlanConfig.getTaskId(),\n partitionPlanConfig.getErrorStrategy(), partitionPlanConfig.getTimeoutMillis());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "submit_bothCreateAndDropTrigger_submitSucceed", + "reference": " @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n PartitionPlanTableConfig tableConfig = new PartitionPlanTableConfig();\n tableConfig.setTableName(MYSQL_REAL_RANGE_TABLE_NAME);\n tableConfig.setPartitionNameInvoker(\"CUSTOM_PARTITION_NAME_GENERATOR\");\n SqlExprBasedGeneratorConfig config = new SqlExprBasedGeneratorConfig();\n config.setGenerateExpr(\"concat('p', date_format(from_unixtime(unix_timestamp(\"\n + \"STR_TO_DATE(20240125, '%Y%m%d')) + \"\n + PartitionPlanVariableKey.INTERVAL.getVariable() + \"), '%Y%m%d'))\");\n config.setIntervalGenerateExpr(\"86400\");\n tableConfig.setPartitionNameInvokerParameters(getSqlExprBasedNameGeneratorParameters(config));\n PartitionPlanKeyConfig c3Create = getMysqlc3CreateConfig();\n PartitionPlanKeyConfig datekeyCreate = getMysqldatekeyCreateConfig();\n PartitionPlanKeyConfig dropConfig = getDropConfig();\n tableConfig.setPartitionKeyConfigs(Arrays.asList(c3Create, datekeyCreate, dropConfig));\n\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setPartitionTableConfigs(Collections.singletonList(tableConfig));\n partitionPlanConfig.setFlowInstanceId(1L);\n partitionPlanConfig.setTimeoutMillis(180000L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setDatabaseId(1L);\n\n long t1 = System.currentTimeMillis();\n long t2 = System.currentTimeMillis() + 1000;\n TriggerConfig createTrigger = TestRandom.nextObject(TriggerConfig.class);\n createTrigger.setStartAt(new Date(t1));\n partitionPlanConfig.setCreationTrigger(createTrigger);\n\n TriggerConfig dropTrigger = TestRandom.nextObject(TriggerConfig.class);\n dropTrigger.setStartAt(new Date(t2));\n partitionPlanConfig.setDroppingTrigger(dropTrigger);\n Database database = TestRandom.nextObject(Database.class);\n database.setId(1L);\n 
Mockito.when(this.databaseService.detail(1L)).thenReturn(database);\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t1;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.when(this.scheduleService.create(Mockito.argThat(s -> {\n if (s == null) {\n return true;\n }\n TriggerConfig config1 = JsonUtils.fromJson(s.getTriggerConfigJson(), TriggerConfig.class);\n return config1.getStartAt().getTime() == t2;\n }))).thenReturn(TestRandom.nextObject(ScheduleEntity.class));\n Mockito.doNothing().when(this.scheduleService).enable(Mockito.isA(ScheduleEntity.class));\n this.partitionPlanScheduleService.submit(partitionPlanConfig);\n List actuals = this.partitionPlanTableRepository.findAll();\n Set scheduleIds = actuals.stream()\n .map(PartitionPlanTableEntity::getScheduleId).collect(Collectors.toSet());\n Assert.assertEquals(scheduleIds.size(), 2);\n }", + "completions": [ + "```java\n @Test\n public void submit_bothCreateAndDropTrigger_submitSucceed() throws ClassNotFoundException, SchedulerException {\n // Given\n PartitionPlanConfig partitionPlanConfig = new PartitionPlanConfig();\n partitionPlanConfig.setDatabaseId(1L);\n partitionPlanConfig.setEnabled(true);\n partitionPlanConfig.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanConfig.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanConfig.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new 
PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanConfig.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanConfig.setTaskId(\"taskId\");\n partitionPlanConfig.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanConfig.setTimeoutMillis(1000L);\n\n Database database = new Database();\n database.setId(1L);\n database.setName(\"database\");\n database.setType(DatabaseType.MYSQL);\n database.setHost(\"localhost\");\n database.setPort(3306);\n database.setUsername(\"root\");\n database.setPassword(\"password\");\n database.setDatabaseName(\"test\");\n database.setProperties(\"{}\");\n database.setStatus(DatabaseStatus.ONLINE);\n database.setCreateTime(new Date());\n database.setUpdateTime(new Date());\n when(databaseService.detail(1L)).thenReturn(database);\n\n PartitionPlanEntity partitionPlanEntity = new PartitionPlanEntity();\n partitionPlanEntity.setId(1L);\n partitionPlanEntity.setDatabaseId(1L);\n partitionPlanEntity.setEnabled(true);\n partitionPlanEntity.setCreationTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 0 * *?\"));\n partitionPlanEntity.setDroppingTrigger(new TriggerConfig(TriggerType.CRON, \"0 0 1 * *?\"));\n partitionPlanEntity.setPartitionTableConfigs(Arrays.asList(\n new PartitionPlanTableConfig(\"table1\", true, \"invoker1\", \"param1\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key1\", PartitionPlanStrategy.CREATE, \"value1\"),\n new PartitionPlanKeyConfig(\"key2\", PartitionPlanStrategy.DROP, \"value2\")\n )),\n new PartitionPlanTableConfig(\"table2\", true, \"invoker2\", \"param2\", Arrays.asList(\n new PartitionPlanKeyConfig(\"key3\", PartitionPlanStrategy.CREATE, \"value3\"),\n new PartitionPlanKeyConfig(\"key4\", PartitionPlanStrategy.DROP, \"value4\")\n ))\n ));\n partitionPlanEntity.setFlowInstanceId(\"flowInstanceId\");\n partitionPlanEntity.setTaskId(\"taskId\");\n 
partitionPlanEntity.setErrorStrategy(ErrorStrategy.CONTINUE);\n partitionPlanEntity.setTimeoutMillis(1000L);\n when(partitionPlanRepository.save(any(PartitionPlanEntity.class))).thenReturn(partitionPlanEntity);\n\n ScheduleEntity createScheduleEntity = new ScheduleEntity();\n createScheduleEntity.setId(1L);\n createScheduleEntity.setDatabaseId(1L);\n createScheduleEntity.setTriggerConfig(new TriggerConfig(TriggerType" + ], + "line": 70, + "token": 530, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 41, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for 
indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAllQueryWithFullTextSearchDisabled", + "reference": " @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n Instant time = Instant.now();\n Trace.KeyValue customField =\n Trace.KeyValue.newBuilder()\n 
.setVStr(\"value\")\n .setKey(\"customField\")\n .setFieldType(Schema.SchemaFieldType.KEYWORD)\n .build();\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time, List.of(customField)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n SearchResult termQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"customField:value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(termQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermStrQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"value\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermStrQuery.hits.size()).isEqualTo(1);\n\n SearchResult noTermNumericQuery =\n strictLogStoreWithoutFts.logSearcher.search(\n TEST_DATASET_NAME,\n \"1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(noTermNumericQuery.hits.size()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testAllQueryWithFullTextSearchDisabled() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"*\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNull(result.getInternalAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 42, + 
"prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testEnsurePollExecutedCommitAsyncCallbacks", + "reference": " @Test\n public void testEnsurePollExecutedCommitAsyncCallbacks() {\n consumer = newConsumer();\n MockCommitCallback callback = new MockCommitCallback();\n completeCommitAsyncApplicationEventSuccessfully();\n doReturn(Fetch.empty()).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n completeFetchedCommittedOffsetApplicationEventSuccessfully(mkMap());\n\n consumer.assign(Collections.singleton(new TopicPartition(\"foo\", 0)));\n assertDoesNotThrow(() -> consumer.commitAsync(new HashMap<>(), callback));\n assertMockCommitCallbackInvoked(() -> consumer.poll(Duration.ZERO),\n callback,\n null);\n }", + "completions": [ + "```java\n @Test\n public void testPoll() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertNotNull(records);\n assertEquals(0, records.count());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 4, + "token_diff": 4 + }, + { 
+ "id": 43, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static URI resolve(URI base, URI ref) {\n String path = ref.getRawPath();\n if (path != null && !path.isEmpty()) {\n path = base.getRawPath();\n if (path == null || !path.endsWith(\"/\")) {\n try {\n base = new URI(base.getScheme(), base.getAuthority(), base.getPath() + '/', null, null);\n } catch (URISyntaxException e) {\n throw new IllegalStateException(e);\n }\n }\n }\n return URIUtils.resolve(base, ref);\n }\n```\n###test function signature:\n```java\n @Test\n void testResolve_BasePathTrailingSlash()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testResolve_BasePathTrailingSlash", + "reference": " @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://host/sub/dir/\");\n assertEquals(\"http://host/sub/dir/file.jar\", resolve(base, \"file.jar\"));\n assertEquals(\"http://host/sub/dir/dir/file.jar\", resolve(base, \"dir/file.jar\"));\n assertEquals(\"http://host/sub/dir/?arg=val\", resolve(base, \"?arg=val\"));\n assertEquals(\"http://host/sub/dir/file?arg=val\", resolve(base, \"file?arg=val\"));\n assertEquals(\"http://host/sub/dir/dir/file?arg=val\", resolve(base, \"dir/file?arg=val\"));\n }", + "completions": [ + "```java\n @Test\n void testResolve_BasePathTrailingSlash() {\n URI base = URI.create(\"http://example.com/path/\");\n URI ref = URI.create(\"http://example.com/path/file.txt\");\n URI expected = URI.create(\"http://example.com/path/file.txt\");\n URI actual = resolve(base, ref);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 25, + "token": 187, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 44, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_collection_of_non_serializable_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_collection_of_non_serializable_object", + "reference": " @Test\n public void 
should_clone_collection_of_non_serializable_object() {\n List original = new ArrayList<>();\n original.add(new NonSerializableObject(\"value\"));\n List cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_collection_of_non_serializable_object() {\n // Given\n List nonSerializableList = new ArrayList<>();\n nonSerializableList.add(new NonSerializableObject(\"test\"));\n nonSerializableList.add(new NonSerializableObject(\"test2\"));\n\n // When\n List clonedList = objectCloner.clone(nonSerializableList);\n\n // Then\n assertNotSame(nonSerializableList, clonedList);\n assertEquals(nonSerializableList.size(), clonedList.size());\n for (int i = 0; i < nonSerializableList.size(); i++) {\n assertNotSame(nonSerializableList.get(i), clonedList.get(i));\n assertEquals(nonSerializableList.get(i), clonedList.get(i));\n }\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 45, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? 
tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssign()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssign", + "reference": " @Test\n public void testAssign() {\n consumer = newConsumer();\n final TopicPartition tp = new TopicPartition(\"foo\", 3);\n consumer.assign(singleton(tp));\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().contains(tp));\n verify(applicationEventHandler).add(any(AssignmentChangeEvent.class));\n verify(applicationEventHandler).add(any(NewTopicsMetadataUpdateRequestEvent.class));\n }", + "completions": [ + "```java\n @Test\n public void testAssign() {\n // Create a mock Consumer instance\n Consumer mockConsumer = mock(Consumer.class);\n\n // Create a collection of TopicPartition instances\n 
Collection partitions = new ArrayList<>();\n partitions.add(new TopicPartition(\"testTopic\", 0));\n partitions.add(new TopicPartition(\"testTopic\", 1));\n\n // Call the assign method\n mockConsumer.assign(partitions);\n\n // Verify that the assign method was called with the correct parameters\n verify(mockConsumer).assign(partitions);\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 46, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String relativePath(Path absolutePath, Path basePath) {\n Preconditions.checkArgument(absolutePath.isAbsolute(), \"absolutePath must be an absolute path\");\n Preconditions.checkArgument(basePath.isAbsolute(), \"basePath must be an absolute path\");\n\n List absolutePathComponents = Lists.newArrayList(toPathComponents(absolutePath)); // make a copy\n List basePathComponents = toPathComponents(basePath);\n boolean hasCommonPrefix = basePathComponents.isEmpty(); // when basePath is \"/\", we always have a common prefix\n\n for (String base : basePathComponents) {\n if (absolutePathComponents.get(0).equals(base)) {\n absolutePathComponents.remove(0);\n hasCommonPrefix = true;\n } else {\n break;\n }\n }\n if (hasCommonPrefix) {\n return PATH_JOINER.join(absolutePathComponents);\n } else {\n return absolutePath.toString();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testRelativePath() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testRelativePath", + "reference": " @Test\n public void testRelativePath() throws Exception {\n assertEquals(\"a\", PathUtils.relativePath(Path.of(\"/a\"), Path.of(\"/\")));\n assertEquals(\"b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/a\")));\n 
assertEquals(\"b/c.json\", PathUtils.relativePath(Path.of(\"/a/b/c.json\"), Path.of(\"/a\")));\n assertEquals(\"c/d/e\", PathUtils.relativePath(Path.of(\"/a/b/c/d/e\"), Path.of(\"/a/b\")));\n assertEquals(\"/a/b\", PathUtils.relativePath(Path.of(\"/a/b\"), Path.of(\"/c/d\"))); // no common prefix\n }", + "completions": [ + "```java\n @Test\n public void testRelativePath() throws Exception {\n Path absolutePath = Paths.get(\"/a/b/c/d\");\n Path basePath = Paths.get(\"/a/b\");\n String expected = \"c/d\";\n String actual = relativePath(absolutePath, basePath);\n assertEquals(expected, actual);\n }\n```\n" + ], + "line": 31, + "token": 234, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 47, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static Boolean isJson(String value) {\n JsonElement jsonElement;\n try {\n jsonElement = JsonParser.parseString(value);\n } catch (Exception e) {\n return false;\n }\n if (jsonElement == null || jsonElement.isJsonNull()) {\n return false;\n }\n if (!jsonElement.isJsonObject() && !jsonElement.isJsonArray()) {\n return false;\n }\n return true;\n }\n```\n###test function signature:\n```java\n @Test\n void testIsJson()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJson", + "reference": " @Test\n void testIsJson() {\n assertThat(JsonUtil.isJson(\"value\")).isFalse();\n }", + "completions": [ + "```java\n @Test\n void testIsJson() {\n // Test case 1: Valid JSON object\n String jsonObject = \"{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}\";\n assertTrue(isJson(jsonObject));\n\n // Test case 2: Valid JSON array\n String jsonArray = \"[{\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"}, {\\\"name\\\":\\\"Jane\\\", \\\"age\\\":25, 
\\\"city\\\":\\\"Chicago\\\"}]\";\n assertTrue(isJson(jsonArray));\n\n // Test case 3: Invalid JSON (not an object or array)\n String invalidJson = \"\\\"name\\\":\\\"John\\\", \\\"age\\\":30, \\\"city\\\":\\\"New York\\\"\";\n assertFalse(isJson(invalidJson));\n\n // Test case 4: Empty string\n String emptyString = \"\";\n assertFalse(isJson(emptyString));\n\n // Test case 5: Null value\n String nullValue = null;\n assertFalse(isJson(nullValue));\n }\n```\n" + ], + "line": 26, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher 
manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchOnlyHistogram()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchOnlyHistogram", + "reference": " @Test\n public void testSearchOnlyHistogram() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n 
strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"_id:Message3 OR _id:Message4\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 0,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(0);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(2);\n\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(1).getDocCount()).isEqualTo(1);\n\n assertThat(\n Long.parseLong(histogram.getBuckets().get(0).getKeyAsString()) >= time.toEpochMilli())\n .isTrue();\n assertThat(\n Long.parseLong(histogram.getBuckets().get(1).getKeyAsString())\n <= time.plusSeconds(10).toEpochMilli())\n .isTrue();\n }", + "completions": [ + "```java\n @Test\n public void testSearchOnlyHistogram() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 0;\n AggBuilder aggBuilder = new AggBuilder();\n aggBuilder.addHistogramAgg(\"testField\", \"testInterval\");\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation().getHistogramAggs().size() > 0);\n assertEquals(\"testField\", result.getAggregation().getHistogramAggs().get(0).getField());\n assertEquals(\"testInterval\", result.getAggregation().getHistogramAggs().get(0).getInterval());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 49, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent", + "reference": " @Test\n public void testReadComponent() {\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n EntityMetadata entityMetadata = mock(EntityMetadata.class);\n VariableSource variableSource = mock(VariableSource.class);\n when(contentPermissionChecker.isPermitted(anyString(), any(), eq(BreadActions.BROWSE), any())).thenReturn(true);\n when(component.getEntityMetadata()).thenReturn(entityMetadata);\n when(entityMetadata.getId()).thenReturn(new DetachedEntityId(\"someid\"));\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(Arrays.asList(asset));\n when(assetVariableResolver.fromAsset(asset)).thenReturn(variableSource);\n ComponentXO componentXO = underTest.readComponent(\"someid\", \"testRepositoryName\");\n\n 
assertThat(componentXO, is(notNullValue()));\n assertThat(componentXO.getId(), is(\"someid\"));\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n Component component = mock(Component.class);\n Asset asset = mock(Asset.class);\n Supplier txSupplier = () -> storageTx;\n\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(txSupplier);\n when(storageTx.findComponent(componentId)).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(ImmutableList.of(asset));\n\n // When\n ComponentXO result = readComponent(repository, componentId);\n\n // Then\n verify(repository).facet(StorageFacet.class);\n verify(storageFacet).txSupplier();\n verify(storageTx).begin();\n verify(storageTx).findComponent(componentId);\n verify(storageTx).browseAssets(component);\n verify(storageTx).close();\n // Add assertions for the result\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 50, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ListResult listEntities(String tableName, String searchIndexName, List matchFilters,\n List queryFilters, List multiMatchFilter, String nextToken, List sorters, Class clazz) {\n SearchQuery searchQuery = createSearchQuery(matchFilters, queryFilters, multiMatchFilter);\n SearchRequest searchRequest = new SearchRequest(tableName, searchIndexName, searchQuery);\n if (!StringUtils.isEmpty(nextToken)) {\n byte[] tokenBytes = EncryptionUtil.decode(nextToken);\n searchRequest.getSearchQuery().setToken(tokenBytes);\n } else {\n if (sorters != null &&!sorters.isEmpty()) {\n searchQuery.setSort(new Sort(sorters));\n }\n }\n SearchRequest.ColumnsToGet columnsToGet = new SearchRequest.ColumnsToGet();\n columnsToGet.setColumns(ReflectionUtil.getPropertyNames(clazz));\n searchRequest.setColumnsToGet(columnsToGet);\n log.info(\"searchRequest:{}\", JsonUtil.toJsonString(searchRequest));\n SearchResponse searchResponse = otsClient.search(searchRequest);\n if (searchResponse == null || searchResponse.getRows() == null) {\n return ListResult.genSuccessListResult(null, 0);\n }\n byte[] nextTokenBytes = searchResponse.getNextToken();\n nextToken = nextTokenBytes == null || nextTokenBytes.length == 0 ? 
null : EncryptionUtil.encode(nextTokenBytes);\n List result = searchResponse.getRows().stream()\n .map(row -> OtsUtil.convertRowToDTO(row, clazz))\n .collect(Collectors.toList());\n return ListResult.genSuccessListResult(result, searchResponse.getTotalCount(), nextToken);\n }\n```\n###test function signature:\n```java\n @Test\n void testListEntities()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListEntities", + "reference": " @Test\n void testListEntities() {\n final List matchFilters = createMatchFilters();\n final List queryFilters = createQueryFilters();\n final ListResult expectedResult = new ListResult<>();\n FieldSort fieldSort = new FieldSort(OrderOtsConstant.GMT_CREATE_LONG);\n fieldSort.setOrder(SortOrder.DESC);\n expectedResult.setData(Arrays.asList());\n expectedResult.setCount(0L);\n String nextToken = \"CAESFQoTChEKDWdtdENyZWF0ZUxvbmcQARgBIlQKCQBI8UqGigEAAApHA0IAAAAxUzM1MzQzMTM0NjQzMjYzMzAzMzYyMzE2MTMzMzkzOTM1MzEzNjM2MzM2NDM2MzAzMDMwNjYzNTM1MzA2NjY0MzM=\";\n expectedResult.setNextToken(nextToken);\n final SearchResponse searchResponse = new SearchResponse(new Response(\"requestId\"));\n when(mockOtsClient.search(any(SearchRequest.class))).thenReturn(searchResponse);\n ListResult result = baseOtsHelper.listEntities(\"order\", \"order_index\", matchFilters, queryFilters, null, nextToken, Arrays.asList(fieldSort), OrderDTO.class);\n assertThat(result.getCount()).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n void testListEntities() {\n // Given\n String tableName = \"testTable\";\n String searchIndexName = \"testIndex\";\n List matchFilters = Arrays.asList(new OtsFilter(\"field1\", OtsFilter.CompareOperator.EQUAL, \"value1\"),\n new OtsFilter(\"field2\", OtsFilter.CompareOperator.EQUAL, \"value2\"));\n List queryFilters = Arrays.asList(new OtsFilter(\"field3\", OtsFilter.CompareOperator.EQUAL, \"value3\"),\n new OtsFilter(\"field4\", OtsFilter.CompareOperator.EQUAL, \"value4\"));\n List multiMatchFilter = 
Arrays.asList(new OtsFilter(\"field5\", OtsFilter.CompareOperator.EQUAL, \"value5\"),\n new OtsFilter(\"field6\", OtsFilter.CompareOperator.EQUAL, \"value6\"));\n String nextToken = \"testNextToken\";\n List sorters = Arrays.asList(new Sort.Sorter(\"field7\", Sort.SortOrder.ASC),\n new Sort.Sorter(\"field8\", Sort.SortOrder.DESC));\n Class clazz = TestDTO.class;\n\n // When\n ListResult result = listEntities(tableName, searchIndexName, matchFilters, queryFilters, multiMatchFilter, nextToken, sorters, clazz);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalCount());\n assertNull(result.getNextToken());\n assertTrue(result.getData().isEmpty());\n }\n```\n" + ], + "line": 38, + "token": 346, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 51, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public String[] listFiles(URI fileUri, boolean recursive) throws IOException {\n try {\n ImmutableList.Builder builder = ImmutableList.builder();\n String continuationToken = null;\n boolean isDone = false;\n String prefix = normalizeToDirectoryPrefix(fileUri);\n int fileCount = 0;\n while (!isDone) {\n ListObjectsV2Request.Builder listObjectsV2RequestBuilder =\n ListObjectsV2Request.builder().bucket(fileUri.getHost());\n if (!prefix.equals(DELIMITER)) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.prefix(prefix);\n }\n if (!recursive) {\n listObjectsV2RequestBuilder = listObjectsV2RequestBuilder.delimiter(DELIMITER);\n }\n if (continuationToken != null) {\n listObjectsV2RequestBuilder.continuationToken(continuationToken);\n }\n ListObjectsV2Request listObjectsV2Request = listObjectsV2RequestBuilder.build();\n LOG.debug(\"Trying to send ListObjectsV2Request {}\", listObjectsV2Request);\n ListObjectsV2Response 
listObjectsV2Response = s3Client.listObjectsV2(listObjectsV2Request);\n LOG.debug(\"Getting ListObjectsV2Response: {}\", listObjectsV2Response);\n List filesReturned = listObjectsV2Response.contents();\n fileCount += filesReturned.size();\n filesReturned.stream()\n .forEach(\n object -> {\n // Only add files and not directories\n if (!object.key().equals(fileUri.getPath())\n && !object.key().endsWith(DELIMITER)) {\n String fileKey = object.key();\n if (fileKey.startsWith(DELIMITER)) {\n fileKey = fileKey.substring(1);\n }\n builder.add(S3_SCHEME + fileUri.getHost() + DELIMITER + fileKey);\n }\n });\n if (fileCount == LIST_MAX_KEYS) {\n // check if we reached the max keys returned, if so abort and throw an error message\n LOG.error(\n \"Too many files ({}) returned from S3 when attempting to list object prefixes\",\n LIST_MAX_KEYS);\n throw new IllegalStateException(\n String.format(\n \"Max keys (%s) reached when attempting to list S3 objects\", LIST_MAX_KEYS));\n }\n isDone = !listObjectsV2Response.isTruncated();\n continuationToken = listObjectsV2Response.nextContinuationToken();\n }\n String[] listedFiles = builder.build().toArray(new String[0]);\n LOG.debug(\n \"Listed {} files from URI: {}, is recursive: {}\", listedFiles.length, fileUri, recursive);\n return listedFiles;\n } catch (Throwable t) {\n throw new IOException(t);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testListFilesInFolderRecursive", + "reference": " @Test\n public void testListFilesInFolderRecursive() throws Exception {\n String folder = \"list-files-rec\";\n String[] nestedFolders = new String[] {\"list-files-child-1\", \"list-files-child-2\"};\n String[] originalFiles = new String[] {\"a-list-3.txt\", \"b-list-3.txt\", \"c-list-3.txt\"};\n\n List expectedResultList = new ArrayList<>();\n for (String childFolder : nestedFolders) {\n String 
folderName = folder + DELIMITER + childFolder;\n for (String fileName : originalFiles) {\n createEmptyFile(folderName, fileName);\n expectedResultList.add(\n String.format(FILE_FORMAT, SCHEME, bucket, folderName + DELIMITER + fileName));\n }\n }\n String[] actualFiles =\n s3BlobFs.listFiles(URI.create(String.format(FILE_FORMAT, SCHEME, bucket, folder)), true);\n\n actualFiles =\n Arrays.stream(actualFiles).filter(x -> x.contains(\"list-3\")).toArray(String[]::new);\n assertEquals(actualFiles.length, expectedResultList.size());\n assertTrue(Arrays.equals(expectedResultList.toArray(), actualFiles));\n }", + "completions": [ + "```java\n @Test\n public void testListFilesInFolderRecursive() throws Exception {\n // Given\n URI fileUri = new URI(\"s3://bucket/folder/\");\n boolean recursive = true;\n String[] expectedFiles = {\n \"s3://bucket/folder/file1.txt\",\n \"s3://bucket/folder/file2.txt\",\n \"s3://bucket/folder/subfolder/file3.txt\",\n \"s3://bucket/folder/subfolder/file4.txt\"\n };\n\n // When\n String[] actualFiles = s3FileSystem.listFiles(fileUri, recursive);\n\n // Then\n assertArrayEquals(expectedFiles, actualFiles);\n }\n```\n" + ], + "line": 70, + "token": 490, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 52, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 
是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_nameDuplicate", + "reference": " @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // mock 数据\n RoleDO roleDO = randomPojo(RoleDO.class, o -> o.setName(\"role_name\"));\n roleMapper.insert(roleDO);\n // 准备参数\n String name = \"role_name\";\n\n // 调用,并断言异常\n assertServiceException(() -> roleService.validateRoleDuplicate(name, randomString(), null),\n ROLE_NAME_DUPLICATE, name);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_nameDuplicate() {\n // Given\n String name = \"duplicateName\";\n String code = \"uniqueCode\";\n Long id = 1L;\n RoleDO existingRole = new RoleDO();\n existingRole.setId(2L);\n existingRole.setName(name);\n existingRole.setCode(code);\n when(roleMapper.selectByName(name)).thenReturn(existingRole);\n\n // When\n try {\n roleService.validateRoleDuplicate(name, code, id);\n fail(\"Expected exception not thrown\");\n } catch (ServiceException e) {\n // Then\n assertEquals(ROLE_NAME_DUPLICATE, e.getCode());\n assertEquals(name, e.getArgs()[0]);\n }\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 53, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public String sign(Map map) {\n try {\n Date date = new Date(System.currentTimeMillis() + jwtProperties.getExpireTimeMills());\n Algorithm algorithm = Algorithm.HMAC256(jwtProperties.getTokenSecret());\n Map header = new HashMap<>(2);\n header.put(\"typ\", \"jwt\");\n JWTCreator.Builder builder = JWT.create()\n .withHeader(header)\n .withIssuedAt(new Date())\n .withExpiresAt(date);\n map.entrySet().forEach(entry -> {\n if (entry.getValue() instanceof Integer) {\n builder.withClaim(entry.getKey(), (Integer) entry.getValue());\n } else if (entry.getValue() instanceof Long) {\n builder.withClaim(entry.getKey(), (Long) entry.getValue());\n } else if (entry.getValue() instanceof Boolean) {\n builder.withClaim(entry.getKey(), (Boolean) entry.getValue());\n } else if (entry.getValue() instanceof String) {\n builder.withClaim(entry.getKey(), String.valueOf(entry.getValue()));\n } else if (entry.getValue() instanceof Double) {\n builder.withClaim(entry.getKey(), (Double) entry.getValue());\n } else if (entry.getValue() instanceof Date) {\n builder.withClaim(entry.getKey(), (Date) entry.getValue());\n }\n });\n return builder.sign(algorithm);\n } catch (Exception e) {\n return null;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void sign_generateJwtToken_succeed()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sign_generateJwtToken_succeed", + "reference": " @Test\n public void sign_generateJwtToken_succeed() {\n Map map = new HashMap<>();\n map.put(\"userId\", 123);\n map.put(\"username\", \"testUser\");\n String token = jwtService.sign(map);\n Assert.assertNotNull(token);\n }", + "completions": [ + "```java\n @Test\n public void sign_generateJwtToken_succeed() {\n // Given\n Map map = new HashMap<>();\n map.put(\"userId\", 
1);\n map.put(\"username\", \"testUser\");\n map.put(\"isAdmin\", true);\n map.put(\"expireTime\", System.currentTimeMillis() + 3600000);\n\n // When\n String jwtToken = jwtService.sign(map);\n\n // Then\n assertNotNull(jwtToken);\n assertTrue(jwtToken.length() > 0);\n }\n```\n" + ], + "line": 41, + "token": 349, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 54, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the 
other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_map_of_non_serializable_value_with_null_value", + "reference": " @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map original = new LinkedHashMap<>();\n original.put(\"null\", null);\n original.put(\"key\", new NonSerializableObject(\"value\"));\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_map_of_non_serializable_value_with_null_value() {\n Map map = new HashMap<>();\n map.put(\"key\", null);\n\n Map clonedMap = (Map) clone(map);\n\n assertNotSame(map, clonedMap);\n assertEquals(map, clonedMap);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 55, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void init(final DiscoveryConfig config) {\n if (this.namingService != null) {\n LOGGER.info(\"Nacos naming service already registered\");\n return;\n }\n Properties properties = config.getProps();\n Properties nacosProperties = new Properties();\n this.groupName = properties.getProperty(\"groupName\", \"SHENYU_GROUP\");\n String serverAddr = config.getServerList();\n nacosProperties.put(PropertyKeyConst.SERVER_ADDR, serverAddr);\n nacosProperties.put(PropertyKeyConst.NAMESPACE, properties.getProperty(NAMESPACE, \"\"));\n nacosProperties.put(PropertyKeyConst.USERNAME, properties.getProperty(PropertyKeyConst.USERNAME, \"\"));\n nacosProperties.put(PropertyKeyConst.PASSWORD, properties.getProperty(PropertyKeyConst.PASSWORD, \"\"));\n nacosProperties.put(PropertyKeyConst.ACCESS_KEY, properties.getProperty(PropertyKeyConst.ACCESS_KEY, \"\"));\n nacosProperties.put(PropertyKeyConst.SECRET_KEY, properties.getProperty(PropertyKeyConst.SECRET_KEY, \"\"));\n try {\n this.namingService = NamingFactory.createNamingService(nacosProperties);\n LOGGER.info(\"Nacos naming service initialized success\");\n } catch (NacosException e) {\n LOGGER.error(\"Error initializing Nacos naming service\", e);\n throw new ShenyuException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testInit", + "reference": " @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n // Set the discovery config\n setField(nacosDiscoveryServiceUnderTest.getClass(), \"namingService\", null);\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n 
config.setServerList(\"127.0.0.1:8848\");\n properties.setProperty(\"groupName\", \"SHENYU_GROUP\");\n config.setProps(properties);\n\n try (MockedStatic mockedNamingFactory = mockStatic(NamingFactory.class)) {\n // Mock the successful creation of NamingService\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenReturn(namingService);\n nacosDiscoveryServiceUnderTest.init(config);\n mockedNamingFactory.verify(() -> NamingFactory.createNamingService(any(Properties.class)));\n assertEquals(namingService, getField(nacosDiscoveryServiceUnderTest, \"namingService\"));\n // Mock the situation where NamingService fails to be created and throws an exception\n mockedNamingFactory.when(() -> NamingFactory.createNamingService(any(Properties.class)))\n .thenThrow(new NacosException());\n assertDoesNotThrow(() -> nacosDiscoveryServiceUnderTest.init(config));\n }\n }", + "completions": [ + "```java\n @Test\n void testInit() throws NoSuchFieldException, IllegalAccessException {\n DiscoveryConfig config = new DiscoveryConfig();\n Properties properties = new Properties();\n properties.put(\"groupName\", \"SHENYU_GROUP\");\n properties.put(\"serverAddr\", \"localhost:8848\");\n properties.put(\"namespace\", \"\");\n properties.put(\"username\", \"\");\n properties.put(\"password\", \"\");\n properties.put(\"accessKey\", \"\");\n properties.put(\"secretKey\", \"\");\n config.setProps(properties);\n config.setServerList(\"localhost:8848\");\n\n NacosDiscoveryService nacosDiscoveryService = new NacosDiscoveryService();\n nacosDiscoveryService.init(config);\n\n Field namingServiceField = NacosDiscoveryService.class.getDeclaredField(\"namingService\");\n namingServiceField.setAccessible(true);\n NamingService namingService = (NamingService) namingServiceField.get(nacosDiscoveryService);\n\n assertNotNull(namingService);\n }\n```\n" + ], + "line": 35, + "token": 273, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 56, + "prompt": "\nYou 
are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current 
offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. 
We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. 
So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyRecoveryNotBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(850, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 2) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(750, 0, indexerConfig));\n\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n 
await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n assertThat(recoveryTaskCreator.determineStartingOffset(1201, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThat(recoveryTaskCreator.determineStartingOffset(1200, 0, indexerConfig))\n .isEqualTo((recoveryStartOffset * 3) + 1);\n assertThatIllegalStateException()\n .isThrownBy(() -> recoveryTaskCreator.determineStartingOffset(1150, 0, indexerConfig));\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore))\n .contains(recoveryTask1, recoveryTask11);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyRecoveryNotBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(1000);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 57, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public boolean isJsonMime(String mediaType) {\n // \"* / *\" is default to JSON\n if (\"*/*\".equals(mediaType)) {\n return true;\n }\n\n try {\n return isJsonMime(MediaType.parseMediaType(mediaType));\n } catch (InvalidMediaTypeException e) {\n }\n return false;\n }\n```\n###test function signature:\n```java\n @Test\n public void testIsJsonMime()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIsJsonMime", + "reference": " @Test\n public void testIsJsonMime() {\n assertFalse(apiClient.isJsonMime(null));\n assertFalse(apiClient.isJsonMime(\"\"));\n assertFalse(apiClient.isJsonMime(\"text/plain\"));\n assertFalse(apiClient.isJsonMime(\"application/xml\"));\n assertFalse(apiClient.isJsonMime(\"application/jsonp\"));\n assertFalse(apiClient.isJsonMime(\"example/json\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+jsonx\"));\n assertFalse(apiClient.isJsonMime(\"example/foo+bar+xjson\"));\n\n assertTrue(apiClient.isJsonMime(\"application/json\"));\n assertTrue(apiClient.isJsonMime(\"application/json; charset=UTF8\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/JSON\"));\n\n assertTrue(apiClient.isJsonMime(\"application/problem+json\"));\n assertTrue(apiClient.isJsonMime(\"APPLICATION/PROBLEM+JSON\"));\n assertTrue(apiClient.isJsonMime(\"application/json\\t\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+bar+json\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json;x;y\"));\n assertTrue(apiClient.isJsonMime(\"example/foo+json\\t;\"));\n assertTrue(apiClient.isJsonMime(\"Example/fOO+JSON\"));\n }", + "completions": [ + "```java\n @Test\n public void testIsJsonMime() {\n // Test case 1: \"* / *\"\n assertTrue(isJsonMime(\"*/*\"));\n\n // Test case 2: \"application/json\"\n 
assertTrue(isJsonMime(\"application/json\"));\n\n // Test case 3: \"application/json; charset=UTF-8\"\n assertTrue(isJsonMime(\"application/json; charset=UTF-8\"));\n\n // Test case 4: \"application/xml\"\n assertFalse(isJsonMime(\"application/xml\"));\n\n // Test case 5: \"text/plain\"\n assertFalse(isJsonMime(\"text/plain\"));\n\n // Test case 6: \"invalid/media/type\"\n assertFalse(isJsonMime(\"invalid/media/type\"));\n\n // Test case 7: null\n assertFalse(isJsonMime(null));\n }\n```\n" + ], + "line": 22, + "token": 139, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public void validateDpopThumprint(String existingThumprint, String requestThumprint) {\n if (StringUtils.isBlank(existingThumprint) && isFalse(appConfiguration.getDpopJktForceForAuthorizationCode())) {\n return; // nothing to check\n }\n\n if (!StringUtils.equals(existingThumprint, requestThumprint)) {\n log.debug(\"DPoP Thumprint between saved one '{}' and send in request '{}' does NOT match. 
Reject request.\", existingThumprint, requestThumprint);\n throw new WebApplicationException(Response\n .status(Response.Status.BAD_REQUEST)\n .type(MediaType.APPLICATION_JSON_TYPE)\n .entity(errorResponseFactory.errorAsJson(TokenErrorResponseType.INVALID_DPOP_PROOF, \"Thumprint does not match\"))\n .cacheControl(ServerUtil.cacheControl(true, false))\n .header(PRAGMA, NO_CACHE)\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass", + "reference": " @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n dpopService.validateDpopThumprint(\"test\", \"test\");\n }", + "completions": [ + "```java\n @Test\n public void validateDpopThumprint_whenExistingDpopThumprintMatchToActual_shouldPass() {\n // Given\n String existingThumprint = \"existingThumprint\";\n String requestThumprint = \"existingThumprint\";\n when(appConfiguration.getDpopJktForceForAuthorizationCode()).thenReturn(false);\n\n // When\n // Call the method to test\n tokenService.validateDpopThumprint(existingThumprint, requestThumprint);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 26, + "token": 209, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 59, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static String getExpandUrl(String url, Object paramObj) {\n UriComponentsBuilder builder = UriComponentsBuilder.fromHttpUrl(url);\n if (paramObj instanceof Map) {\n for (Map.Entry param : ((Map) paramObj).entrySet()) {\n if (null != param.getValue()) {\n builder.queryParam(param.getKey(), param.getValue());\n }\n }\n } else {\n for (Field param : paramObj.getClass().getDeclaredFields()) {\n param.setAccessible(true);\n\n try {\n Object value = param.get(paramObj);\n if (null != value) {\n if (value instanceof List) {\n builder.queryParam(param.getName(), (List) value);\n } else {\n builder.queryParam(param.getName(), value);\n }\n }\n } catch (IllegalAccessException e) {\n log.warn(\"Field: {} value retrieval failed, exception: {}\", param.getName(), e);\n }\n }\n }\n\n return builder.build().encode().toUri().toString();\n }\n```\n###test function signature:\n```java\n @Test\n void testGetExpandUrl()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetExpandUrl", + "reference": " @Test\n void testGetExpandUrl() {\n String result = HttpUtil.getExpandUrl(\"http://baidu.com\", \"paramObj\");\n Assertions.assertNotNull(result);\n }", + "completions": [ + "```java\n @Test\n void testGetExpandUrl() {\n // Test with a Map\n Map map = new HashMap<>();\n map.put(\"param1\", \"value1\");\n map.put(\"param2\", \"value2\");\n String url = \"http://example.com\";\n String expected = \"http://example.com?param1=value1¶m2=value2\";\n String actual = getExpandUrl(url, map);\n assertEquals(expected, actual);\n\n // Test with a custom object\n TestObject obj = new TestObject(\"value1\", \"value2\");\n expected = \"http://example.com?param1=value1¶m2=value2\";\n actual = getExpandUrl(url, obj);\n assertEquals(expected, actual);\n }\n\n public static 
class TestObject {\n private String param1;\n private String param2;\n\n public TestObject(String param1, String param2) {\n this.param1 = param1;\n this.param2 = param2;\n }\n\n public String getParam1() {\n return param1;\n }\n\n public String getParam2() {\n return param2;\n }\n }\n```\n" + ], + "line": 38, + "token": 277, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 60, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"asdf\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForInvalidExpirationTime() throws Exception {\n String invalidExpirationTime = \"invalid\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(invalidExpirationTime);\n assertEquals(CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 61, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Get\n @Path(\"/:indexName/_mapping\")\n public HttpResponse mapping(\n @Param(\"indexName\") Optional indexName,\n @Param(\"startTimeEpochMs\") Optional startTimeEpochMs,\n @Param(\"endTimeEpochMs\") Optional endTimeEpochMs)\n throws IOException {\n // Use a tree map so the results are naturally sorted\n Map> propertiesMap = new TreeMap<>();\n\n // we default the schema search to the last hour if params are not provided\n AstraSearch.SchemaResult schemaResult =\n searcher.getSchema(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(\n startTimeEpochMs.orElse(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build());\n\n Map schema = SearchResultUtils.fromSchemaResultProto(schemaResult);\n schema.forEach((key, value) -> propertiesMap.put(key, Map.of(\"type\", value.getName())));\n\n // todo - remove this after we add support for a \"date\" type\n // override the timestamp as a date field for proper autocomplete\n propertiesMap.put(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, Map.of(\"type\", \"date\"));\n\n return HttpResponse.of(\n HttpStatus.OK,\n MediaType.JSON,\n JsonUtil.writeAsString(\n ImmutableMap.of(\n indexName.orElseThrow(),\n ImmutableMap.of(\"mappings\", ImmutableMap.of(\"properties\", propertiesMap)))));\n }\n```\n###test function signature:\n```java\n @Test\n public void testIndexMapping() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testIndexMapping", + "reference": " @Test\n public void testIndexMapping() throws IOException {\n AstraQueryServiceBase searcher = mock(AstraQueryServiceBase.class);\n ElasticsearchApiService serviceUnderTest = new 
ElasticsearchApiService(searcher);\n\n Instant start = Instant.now();\n Instant end = start.minusSeconds(60);\n\n when(searcher.getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build())))\n .thenReturn(AstraSearch.SchemaResult.newBuilder().build());\n\n HttpResponse response =\n serviceUnderTest.mapping(\n Optional.of(\"foo\"), Optional.of(start.toEpochMilli()), Optional.of(end.toEpochMilli()));\n verify(searcher)\n .getSchema(\n eq(\n AstraSearch.SchemaRequest.newBuilder()\n .setDataset(\"foo\")\n .setStartTimeEpochMs(start.toEpochMilli())\n .setEndTimeEpochMs(end.toEpochMilli())\n .build()));\n\n // handle response\n AggregatedHttpResponse aggregatedRes = response.aggregate().join();\n String body = aggregatedRes.content(StandardCharsets.UTF_8);\n JsonNode jsonNode = new ObjectMapper().readTree(body);\n\n assertThat(aggregatedRes.status().code()).isEqualTo(200);\n\n assertThat(jsonNode.findValue(\"foo\")).isNotNull();\n assertThat(\n jsonNode.findValue(\"foo\").findValue(LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName))\n .isNotNull();\n\n when(searcher.getSchema(any()))\n .thenAnswer(\n invocationOnMock -> {\n AstraSearch.SchemaRequest request =\n ((AstraSearch.SchemaRequest) invocationOnMock.getArguments()[0]);\n assertThat(request.getDataset()).isEqualTo(\"bar\");\n assertThat(request.getStartTimeEpochMs())\n .isCloseTo(\n Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli(),\n Offset.offset(1000L));\n assertThat(request.getEndTimeEpochMs())\n .isCloseTo(Instant.now().toEpochMilli(), Offset.offset(1000L));\n return AstraSearch.SchemaResult.newBuilder().build();\n });\n serviceUnderTest.mapping(Optional.of(\"bar\"), Optional.empty(), Optional.empty());\n }", + "completions": [ + "```java\n @Test\n public void testIndexMapping() throws IOException {\n Optional indexName = Optional.of(\"testIndex\");\n Optional startTimeEpochMs = 
Optional.of(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli());\n Optional endTimeEpochMs = Optional.of(Instant.now().toEpochMilli());\n\n AstraSearch.SchemaResult schemaResult =\n AstraSearch.SchemaResult.newBuilder()\n .setDataset(indexName.orElse(\"*\"))\n .setStartTimeEpochMs(startTimeEpochMs.orElse(Instant.now().minus(1, ChronoUnit.HOURS).toEpochMilli()))\n .setEndTimeEpochMs(endTimeEpochMs.orElse(Instant.now().toEpochMilli()))\n .build();\n\n when(searcher.getSchema(any(AstraSearch.SchemaRequest.class))).thenReturn(schemaResult);\n\n HttpResponse response = mapping(indexName, startTimeEpochMs, endTimeEpochMs);\n\n assertEquals(HttpStatus.OK, response.getStatus());\n assertEquals(MediaType.JSON, response.getContentType());\n assertNotNull(response.getContent());\n }\n```\n" + ], + "line": 43, + "token": 354, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 62, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. 
A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "@SuppressWarnings", + "reference": " @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n consumer = newConsumer();\n Exception e = assertThrows(UnsupportedOperationException.class, () -> consumer.poll(0L));\n assertEquals(\"Consumer.poll(long) is not supported when \\\"group.protocol\\\" is \\\"consumer\\\". 
\" +\n \"This method is deprecated and will be removed in the next major release.\", e.getMessage());\n }", + "completions": [ + "```java\n @Test\n @SuppressWarnings(\"deprecation\")\n public void testPollLongThrowsException() {\n // Given\n Duration timeout = Duration.ofMillis(1000);\n Consumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // When\n ConsumerRecords records = consumer.poll(timeout);\n\n // Then\n assertTrue(records.isEmpty());\n }\n```\n" + ], + "line": 46, + "token": 373, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 63, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = 
Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testNullSearchString()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testNullSearchString", + "reference": " @Test\n public void testNullSearchString() {\n Instant time = 
Instant.ofEpochSecond(1593365471);\n loadTestData(time);\n\n assertThatExceptionOfType(IllegalArgumentException.class)\n .isThrownBy(\n () ->\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME + \"miss\",\n null,\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\")));\n }", + "completions": [ + "```java\n @Test\n public void testNullSearchString() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = null;\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n try {\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n fail(\"Expected an IllegalArgumentException to be thrown\");\n } catch (IllegalArgumentException e) {\n // Expected\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 64, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture sendMessageBack(ProxyContext ctx, ReceiptHandle handle, String messageId,\n ConsumerSendMsgBackRequestHeader requestHeader, long timeoutMillis) {\n // Build the response.\n final RemotingCommand response = RemotingCommand.createResponseCommand(ResponseCode.SUCCESS, null, null);\n\n Integer delayLevel = requestHeader.getDelayLevel();\n if (Objects.isNull(delayLevel)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument delay level is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n Long offset = requestHeader.getOffset();\n if (Objects.isNull(offset)) {\n response.setCode(ResponseCode.ILLEGAL_OPERATION);\n response.setRemark(\"argument offset is null\");\n return CompletableFuture.completedFuture(response);\n }\n\n VirtualQueue virtualQueue = new VirtualQueue(requestHeader.getBname());\n\n CompletableFuture topicFuture = topicOf(requestHeader.getOriginTopic());\n CompletableFuture groupFuture = consumerGroupOf(requestHeader.getGroup());\n return topicFuture.thenCombine(groupFuture, (topic, group) -> {\n if (topic.getTopicId() != virtualQueue.topicId()) {\n LOGGER.error(\"Topic id in request header {} does not match topic id in message queue {}, maybe the topic is recreated.\",\n topic.getTopicId(), virtualQueue.topicId());\n throw new ProxyException(apache.rocketmq.v2.Code.TOPIC_NOT_FOUND, \"Topic resource does not exist.\");\n }\n return Pair.of(topic, group);\n }).thenCompose(pair -> {\n Topic topic = pair.getLeft();\n ConsumerGroup group = pair.getRight();\n\n return store.pull(StoreContext.EMPTY, group.getGroupId(), topic.getTopicId(), virtualQueue.physicalQueueId(),\n Filter.DEFAULT_FILTER, requestHeader.getOffset(), 1, false)\n .thenApply(pullResult -> {\n if 
(pullResult.status() == com.automq.rocketmq.store.model.message.PullResult.Status.FOUND) {\n return pullResult.messageList().get(0);\n }\n throw new ProxyException(apache.rocketmq.v2.Code.MESSAGE_NOT_FOUND, \"Message not found from server.\");\n }).thenCompose(messageExt -> {\n if (messageExt.deliveryAttempts() > group.getMaxDeliveryAttempt()) {\n return deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n }\n\n // Message consume retry strategy\n // <0: no retry,put into DLQ directly\n // =0: broker control retry frequency\n // >0: client control retry frequency\n return switch (Integer.compare(delayLevel, 0)) {\n case -1 ->\n deadLetterService.send((ProxyContextExt) ctx, group.getGroupId(), messageExt.message());\n case 0 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n // Keep the same logic as apache RocketMQ.\n int serverDelayLevel = messageExt.deliveryAttempts() + 1;\n messageExt.setDeliveryAttempts(serverDelayLevel);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(serverDelayLevel));\n return store.put(StoreContext.EMPTY, message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put messageExt to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n case 1 -> topicOf(MixAll.RETRY_GROUP_TOPIC_PREFIX + requestHeader.getGroup())\n .thenCompose(retryTopic -> {\n messageExt.setDeliveryAttempts(messageExt.deliveryAttempts() + 1);\n messageExt.setOriginalQueueOffset(messageExt.originalOffset());\n\n FlatMessage message = messageExt.message();\n message.mutateTopicId(retryTopic.getTopicId());\n\n message.systemProperties().mutateDeliveryTimestamp(FlatMessageUtil.calculateDeliveryTimestamp(delayLevel));\n return store.put(StoreContext.EMPTY, 
message)\n .exceptionally(ex -> {\n LOGGER.error(\"Put message to retry topic failed\", ex);\n return null;\n })\n .thenApply(ignore -> null);\n });\n default -> throw new IllegalStateException(\"Never reach here\");\n };\n });\n }).whenComplete((nil, throwable) -> {\n if (throwable != null) {\n response.setCode(ResponseCode.SYSTEM_ERROR);\n response.setRemark(throwable.getMessage());\n }\n }).thenApply(nil -> response);\n }\n```\n###test function signature:\n```java\n @Test\n void sendMessageBack()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "sendMessageBack", + "reference": " @Test\n void sendMessageBack() {\n String topicName = \"topic\";\n VirtualQueue virtualQueue = new VirtualQueue(2, 0);\n ConsumerSendMsgBackRequestHeader header = new ConsumerSendMsgBackRequestHeader();\n header.setBname(virtualQueue.brokerName());\n header.setGroup(\"group\");\n header.setOriginTopic(topicName);\n header.setOffset(0L);\n header.setDelayLevel(0);\n\n CompletionException exception = assertThrowsExactly(CompletionException.class, () -> messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join());\n assertInstanceOf(ProxyException.class, ExceptionUtils.getRealException(exception));\n ProxyException realException = (ProxyException) ExceptionUtils.getRealException(exception);\n assertEquals(realException.getErrorCode(), Code.MESSAGE_NOT_FOUND);\n assertEquals(realException.getMessage(), \"Message not found from server.\");\n\n long topicId = metadataService.topicOf(topicName).join().getTopicId();\n messageStore.put(StoreContext.EMPTY, FlatMessageUtil.convertTo(ProxyContextExt.create(), topicId, 0, \"\", new Message(topicName, \"\", new byte[] {})));\n\n // Broker controlled delay level.\n header.setDelayLevel(0);\n RemotingCommand response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, 
Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Client controlled delay level.\n header.setDelayLevel(16);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.never()).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n\n // Forward message into dead letter topic.\n header.setDelayLevel(-1);\n response = messageService.sendMessageBack(ProxyContextExt.create(), null, null, header, 0L).join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n Mockito.verify(deadLetterSender, Mockito.times(1)).send(Mockito.any(), Mockito.anyLong(), Mockito.any());\n }", + "completions": [ + "```java\n @Test\n void sendMessageBack() {\n // Given\n ProxyContext ctx = mock(ProxyContext.class);\n ReceiptHandle handle = mock(ReceiptHandle.class);\n String messageId = \"testMessageId\";\n ConsumerSendMsgBackRequestHeader requestHeader = mock(ConsumerSendMsgBackRequestHeader.class);\n long timeoutMillis = 1000L;\n\n when(requestHeader.getDelayLevel()).thenReturn(1);\n when(requestHeader.getOffset()).thenReturn(1L);\n when(requestHeader.getOriginTopic()).thenReturn(\"testOriginTopic\");\n when(requestHeader.getGroup()).thenReturn(\"testGroup\");\n when(requestHeader.getBname()).thenReturn(\"testBname\");\n\n Topic topic = mock(Topic.class);\n when(topic.getTopicId()).thenReturn(1);\n\n ConsumerGroup group = mock(ConsumerGroup.class);\n when(group.getGroupId()).thenReturn(\"testGroupId\");\n when(group.getMaxDeliveryAttempt()).thenReturn(3);\n\n FlatMessage message = mock(FlatMessage.class);\n when(message.deliveryAttempts()).thenReturn(1);\n when(message.originalOffset()).thenReturn(1L);\n\n MessageExt messageExt = mock(MessageExt.class);\n when(messageExt.message()).thenReturn(message);\n when(messageExt.deliveryAttempts()).thenReturn(1);\n\n PullResult pullResult = mock(PullResult.class);\n 
when(pullResult.status()).thenReturn(PullResult.Status.FOUND);\n when(pullResult.messageList()).thenReturn(List.of(messageExt));\n\n Store store = mock(Store.class);\n when(store.pull(any(), any(), any(), any(), any(), any(), any(), any())).thenReturn(CompletableFuture.completedFuture(pullResult));\n when(store.put(any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n DeadLetterService deadLetterService = mock(DeadLetterService.class);\n when(deadLetterService.send(any(), any(), any())).thenReturn(CompletableFuture.completedFuture(null));\n\n ProxyContextExt ctxExt = mock(ProxyContextExt.class);\n\n TopicService topicService = mock(TopicService.class);\n when(topicService.topicOf(any())).thenReturn(CompletableFuture.completedFuture(topic));\n when(topicService.consumerGroupOf(any())).thenReturn(CompletableFuture.completedFuture(group));\n\n ConsumerService consumerService = new ConsumerService(store, deadLetterService, topicService);\n\n // When\n CompletableFuture future = consumerService.sendMessageBack(ctx, handle, messageId, requestHeader, timeoutMillis);\n\n // Then\n RemotingCommand response = future.join();\n assertEquals(ResponseCode.SUCCESS, response.getCode());\n }\n```\n" + ], + "line": 99, + "token": 793, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 65, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInRequestThreads()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInRequestThreads", + "reference": " @Test\n public void testGetTracerInRequestThreads() {\n ApolloAuditTracer mockTracer = new ApolloAuditTracer(Mockito.mock(ApolloAuditScopeManager.class), supplier);\n RequestAttributes mockRequestAttributes = Mockito.mock(RequestAttributes.class);\n RequestContextHolder.setRequestAttributes(mockRequestAttributes);\n Mockito.when(mockRequestAttributes.getAttribute(Mockito.eq(ApolloAuditConstants.TRACER), Mockito.eq(RequestAttributes.SCOPE_REQUEST)))\n .thenReturn(mockTracer);\n ApolloAuditTracer get = traceContext.tracer();\n assertNotNull(get);\n Mockito.verify(traceContext, Mockito.times(0))\n .setTracer(Mockito.any(ApolloAuditTracer.class));\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInRequestThreads() {\n // Given\n RequestAttributes requestAttributes = Mockito.mock(RequestAttributes.class);\n ApolloAuditTracer tracer = Mockito.mock(ApolloAuditTracer.class);\n Mockito.when(requestAttributes.getAttribute(ApolloAuditConstants.TRACER, RequestAttributes.SCOPE_REQUEST)).thenReturn(tracer);\n 
Mockito.when(RequestContextHolder.getRequestAttributes()).thenReturn(requestAttributes);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n Assert.assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Map generate(String modelName) {\n Map kv = new HashMap<>();\n if (modelName != null) {\n final Schema schema = this.globalContext.getSchemaMap().get(modelName);\n if (schema != null) {\n String example = Json.pretty(this.resolveModelToExample(modelName, schema));\n\n if (example != null) {\n kv.put(EXAMPLE, example);\n return Map.copyOf(kv);\n }\n }\n }\n return Collections.emptyMap();\n }\n```\n###test function signature:\n```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "shouldGenerateOneOfWhenOneOfInRoot", + "reference": " @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n OpenAPIModelGenerator generator = setupPayloadGenerator();\n\n Map example = generator.generate(\"PetType\");\n String exampleJson = example.get(\"example\");\n\n Assertions.assertThat(exampleJson)\n .contains(\"PetTypeONE_OF#/components/schemas/Husky\")\n .contains(\"PetTypeONE_OF#/components/schemas/Labrador\");\n }", + "completions": [ + "```java\n @Test\n void shouldGenerateOneOfWhenOneOfInRoot() throws Exception {\n // Given\n String modelName = \"ModelName\";\n Schema schema = new Schema();\n schema.setOneOf(Arrays.asList(new Schema(), new Schema()));\n this.globalContext.getSchemaMap().put(modelName, schema);\n\n // When\n Map result = this.generate(modelName);\n\n // Then\n assertNotNull(result);\n 
assertFalse(result.isEmpty());\n assertTrue(result.containsKey(EXAMPLE));\n assertNotNull(result.get(EXAMPLE));\n }\n```\n" + ], + "line": 25, + "token": 181, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 67, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ApolloAuditTracer tracer() {\n RequestAttributes requestAttributes = RequestContextHolder.getRequestAttributes();\n if (requestAttributes != null) {\n Object tracer = requestAttributes.getAttribute(ApolloAuditConstants.TRACER,\n RequestAttributes.SCOPE_REQUEST);\n if (tracer != null) {\n return ((ApolloAuditTracer) tracer);\n } else {\n ApolloAuditTracer newTracer = new ApolloAuditTracer(new ApolloAuditScopeManager(), operatorSupplier);\n setTracer(newTracer);\n return newTracer;\n }\n }\n return null;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetTracerInAnotherThreadButSameRequest", + "reference": " @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n ApolloAuditTracer mockTracer = Mockito.mock(ApolloAuditTracer.class);\n {\n Mockito.when(traceContext.tracer()).thenReturn(mockTracer);\n }\n CountDownLatch latch = new CountDownLatch(1);\n Executors.newSingleThreadExecutor().submit(() -> {\n ApolloAuditTracer tracer = traceContext.tracer();\n\n assertEquals(mockTracer, tracer);\n\n latch.countDown();\n });\n }", + "completions": [ + "```java\n @Test\n public void testGetTracerInAnotherThreadButSameRequest() {\n // Given\n RequestAttributes requestAttributes = new ServletRequestAttributes(new MockHttpServletRequest());\n RequestContextHolder.setRequestAttributes(requestAttributes);\n ApolloAuditTracer tracer = new ApolloAuditTracer(new 
ApolloAuditScopeManager(), operatorSupplier);\n requestAttributes.setAttribute(ApolloAuditConstants.TRACER, tracer, RequestAttributes.SCOPE_REQUEST);\n\n // When\n ApolloAuditTracer result = tracer();\n\n // Then\n assertEquals(tracer, result);\n }\n```\n" + ], + "line": 26, + "token": 159, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, 
startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testSearchMultipleItemsAndIndices()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSearchMultipleItemsAndIndices", + "reference": " @Test\n public void testSearchMultipleItemsAndIndices() {\n Instant time = Instant.now();\n loadTestData(time);\n SearchResult babies =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(2).toEpochMilli(),\n 10,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"));\n assertThat(babies.hits.size()).isEqualTo(1);\n\n InternalDateHistogram histogram =\n (InternalDateHistogram) Objects.requireNonNull(babies.internalAggregation);\n assertThat(histogram.getBuckets().size()).isEqualTo(1);\n assertThat(histogram.getBuckets().get(0).getDocCount()).isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testSearchMultipleItemsAndIndices() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new AggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(10, result.getResults().size());\n assertTrue(result.getElapsedTimeMicros() > 0);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getFailedShards());\n assertEquals(1, result.getSuccessfulShards());\n assertNotNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 69, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n static Map>> getCandidateJobs(String tableNameWithType,\n Map> allJobMetadata)\n throws Exception {\n long nowMs = System.currentTimeMillis();\n Map>> candidates = new HashMap<>();\n // If the job started most recently has already completed, then skip retry for the table.\n Pair latestStartedJob = null;\n Pair latestCompletedJob = null;\n // The processing order of job metadata from the given Map is not deterministic. 
Track the completed original\n // jobs so that we can simply skip the retry jobs belonging to the completed original jobs.\n Map completedOriginalJobs = new HashMap<>();\n Set cancelledOriginalJobs = new HashSet<>();\n for (Map.Entry> entry : allJobMetadata.entrySet()) {\n String jobId = entry.getKey();\n Map jobMetadata = entry.getValue();\n long statsUpdatedAt = Long.parseLong(jobMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));\n String jobStatsInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);\n if (StringUtils.isEmpty(jobStatsInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job progress stats\", jobId);\n continue;\n }\n String jobCtxInStr = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n if (StringUtils.isEmpty(jobCtxInStr)) {\n LOGGER.info(\"Skip rebalance job: {} as it has no job context\", jobId);\n continue;\n }\n TableRebalanceProgressStats jobStats = JsonUtils.stringToObject(jobStatsInStr, TableRebalanceProgressStats.class);\n TableRebalanceContext jobCtx = JsonUtils.stringToObject(jobCtxInStr, TableRebalanceContext.class);\n long jobStartTimeMs = jobStats.getStartTimeMs();\n if (latestStartedJob == null || latestStartedJob.getRight() < jobStartTimeMs) {\n latestStartedJob = Pair.of(jobId, jobStartTimeMs);\n }\n String originalJobId = jobCtx.getOriginalJobId();\n RebalanceResult.Status jobStatus = jobStats.getStatus();\n if (jobStatus == RebalanceResult.Status.DONE || jobStatus == RebalanceResult.Status.NO_OP) {\n LOGGER.info(\"Skip rebalance job: {} as it has completed with status: {}\", jobId, jobStatus);\n completedOriginalJobs.put(originalJobId, jobId);\n if (latestCompletedJob == null || latestCompletedJob.getRight() < jobStartTimeMs) {\n latestCompletedJob = Pair.of(jobId, jobStartTimeMs);\n }\n continue;\n }\n if (jobStatus == RebalanceResult.Status.FAILED || jobStatus == RebalanceResult.Status.ABORTED) {\n LOGGER.info(\"Found rebalance job: {} for 
original job: {} has been stopped with status: {}\", jobId,\n originalJobId, jobStatus);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n continue;\n }\n if (jobStatus == RebalanceResult.Status.CANCELLED) {\n LOGGER.info(\"Found cancelled rebalance job: {} for original job: {}\", jobId, originalJobId);\n cancelledOriginalJobs.add(originalJobId);\n continue;\n }\n // Check if an IN_PROGRESS job is still actively running.\n long heartbeatTimeoutMs = jobCtx.getConfig().getHeartbeatTimeoutInMs();\n if (nowMs - statsUpdatedAt < heartbeatTimeoutMs) {\n LOGGER.info(\"Rebalance job: {} is actively running with status updated at: {} within timeout: {}. Skip \"\n + \"retry for table: {}\", jobId, statsUpdatedAt, heartbeatTimeoutMs, tableNameWithType);\n return Collections.emptyMap();\n }\n // The job is considered failed, but it's possible it is still running, then we might end up with more than one\n // rebalance jobs running in parallel for a table. The rebalance algorithm is idempotent, so this should be fine\n // for the correctness.\n LOGGER.info(\"Found stuck rebalance job: {} for original job: {}\", jobId, originalJobId);\n candidates.computeIfAbsent(originalJobId, (k) -> new HashSet<>()).add(Pair.of(jobCtx, jobStartTimeMs));\n }\n if (latestCompletedJob != null && latestCompletedJob.getLeft().equals(latestStartedJob.getLeft())) {\n LOGGER.info(\"Rebalance job: {} started most recently has already done. 
Skip retry for table: {}\",\n latestCompletedJob.getLeft(), tableNameWithType);\n return Collections.emptyMap();\n }\n for (String jobId : cancelledOriginalJobs) {\n LOGGER.info(\"Skip original job: {} as it's cancelled\", jobId);\n candidates.remove(jobId);\n }\n for (Map.Entry entry : completedOriginalJobs.entrySet()) {\n LOGGER.info(\"Skip original job: {} as it's completed by attempt: {}\", entry.getKey(), entry.getValue());\n candidates.remove(entry.getKey());\n }\n return candidates;\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetCandidateJobs()\n throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetCandidateJobs", + "reference": " @Test\n public void testGetCandidateJobs()\n throws Exception {\n String tableName = \"table01\";\n Map> allJobMetadata = new HashMap<>();\n\n // Original job run as job1, and all its retry jobs failed too.\n RebalanceConfig jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n TableRebalanceProgressStats stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(1000);\n TableRebalanceContext jobCtx = TableRebalanceContext.forInitialAttempt(\"job1\", jobCfg);\n Map jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job1\", stats, jobCtx);\n allJobMetadata.put(\"job1\", jobMetadata);\n // 3 failed retry runs for job1\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 2, 1100, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 3, 1200, RebalanceResult.Status.ABORTED);\n allJobMetadata.put(\"job1_3\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job1\", 4, 1300, RebalanceResult.Status.FAILED);\n allJobMetadata.put(\"job1_4\", jobMetadata);\n\n // Original job run as job2, and its retry job job2_1 completed.\n jobCfg = new RebalanceConfig();\n 
jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(2000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job2\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job2\", stats, jobCtx);\n allJobMetadata.put(\"job2\", jobMetadata);\n jobMetadata = createDummyJobMetadata(tableName, \"job2\", 2, 2100, RebalanceResult.Status.DONE);\n allJobMetadata.put(\"job2_2\", jobMetadata);\n\n // Original job run as job3, and failed to send out heartbeat in time.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.IN_PROGRESS);\n stats.setStartTimeMs(3000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job3\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job3\", stats, jobCtx);\n jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"3000\");\n allJobMetadata.put(\"job3\", jobMetadata);\n\n // Original job run as job4, which didn't have retryJobCfg as from old version of the code.\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.FAILED);\n stats.setStartTimeMs(4000);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job4\", stats, null);\n jobMetadata.remove(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);\n allJobMetadata.put(\"job4\", jobMetadata);\n\n // Only need to retry job1 and job3, as job2 is completed and job4 is from old version of code.\n Map>> jobs =\n RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 2);\n assertTrue(jobs.containsKey(\"job1\"));\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job1\").size(), 4); // four runs including job1,job1_1,job1_2,job1_3\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Abort job1 
and cancel its retries, then only job3 is retry candidate.\n jobMetadata = allJobMetadata.get(\"job1_4\");\n cancelRebalanceJob(jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 1);\n assertTrue(jobs.containsKey(\"job3\"));\n assertEquals(jobs.get(\"job3\").size(), 1); // just a single run job3\n\n // Add latest job5 that's already done, thus no need to retry for table.\n jobCfg = new RebalanceConfig();\n jobCfg.setMaxAttempts(4);\n stats = new TableRebalanceProgressStats();\n stats.setStatus(RebalanceResult.Status.DONE);\n stats.setStartTimeMs(5000);\n jobCtx = TableRebalanceContext.forInitialAttempt(\"job5\", jobCfg);\n jobMetadata = ZkBasedTableRebalanceObserver.createJobMetadata(tableName, \"job5\", stats, jobCtx);\n allJobMetadata.put(\"job5\", jobMetadata);\n jobs = RebalanceChecker.getCandidateJobs(tableName, allJobMetadata);\n assertEquals(jobs.size(), 0);\n }", + "completions": [ + "```java\n @Test\n public void testGetCandidateJobs()\n throws Exception {\n // Given\n String tableNameWithType = \"table1_type1\";\n Map> allJobMetadata = new HashMap<>();\n Map jobMetadata1 = new HashMap<>();\n jobMetadata1.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"FAILED\\\"}\");\n jobMetadata1.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job1\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job1\", jobMetadata1);\n Map jobMetadata2 = new HashMap<>();\n jobMetadata2.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, \"1640995200000\");\n jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,\n \"{\\\"startTimeMs\\\":1640995200000,\\\"status\\\":\\\"CANCELLED\\\"}\");\n 
jobMetadata2.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,\n \"{\\\"originalJobId\\\":\\\"job2\\\",\\\"config\\\":{\\\"heartbeatTimeoutInMs\\\":60000}}\");\n allJobMetadata.put(\"job2\", jobMetadata2);\n\n // When\n Map>> result = getCandidateJobs(tableNameWithType, allJobMetadata);\n\n // Then\n assertEquals(1, result.size());\n assertTrue(result.containsKey(\"job2\"));\n Set> jobSet = result.get(\"job2\");\n assertEquals(1, jobSet.size());\n Pair jobPair = jobSet.iterator().next();\n assertEquals(\"job2\", jobPair.getLeft().getOriginalJobId());\n assertEquals(1640995200000L, jobPair.getRight().longValue());\n }\n```\n" + ], + "line": 94, + "token": 952, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 70, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession 
session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n 
.setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchError()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchError", + "reference": " @Test\n public void distClientMatchError() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.ERROR);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.ERROR);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n 
.map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.NO_SUB));\n for (IMQTTTransientSession session : sessions) {\n verify(session, never()).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchError() {\n // Create a mock for the ITenantMeter interface\n ITenantMeter mockTenantMeter = Mockito.mock(ITenantMeter.class);\n // Create a mock for the ILocalDistService interface\n ILocalDistService mockLocalDistService = Mockito.mock(ILocalDistService.class);\n // Create a mock for the sessionMap\n Map mockSessionMap = new HashMap<>();\n // Create a mock for the routeMap\n Map> mockRouteMap = new HashMap<>();\n // Create a mock for the resourceThrottler\n IResourceThrottler mockResourceThrottler = Mockito.mock(IResourceThrottler.class);\n\n // Set up the mock objects\n Mockito.when(mockTenantMeter.get(Mockito.anyString())).thenReturn(mockTenantMeter);\n Mockito.when(mockLocalDistService.isGlobal(Mockito.anyString())).thenReturn(false);\n Mockito.when(mockSessionMap.get(Mockito.anyString())).thenReturn(null);\n Mockito.when(mockRouteMap.get(Mockito.any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(null));\n Mockito.when(mockResourceThrottler.hasResource(Mockito.anyString(), Mockito.anyLong())).thenReturn(false);\n\n // Create a new instance of the class under test\n DistService distService = new DistService(mockTenantMeter, mockLocalDistService, mockSessionMap, mockRouteMap, 
mockResourceThrottler);\n\n // Create a test request\n DeliveryRequest request = DeliveryRequest.newBuilder().build();\n\n // Call the method under test\n CompletableFuture replyFuture = distService.dist(request);\n\n // Verify the results\n DeliveryReply reply = replyFuture.join();\n Assert.assertEquals(0, reply.getResultMap().size());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 71, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture dist(DeliveryRequest request) {\n DeliveryReply.Builder replyBuilder = DeliveryReply.newBuilder();\n DeliveryResults.Builder resultsBuilder = DeliveryResults.newBuilder();\n for (Map.Entry entry : request.getPackageMap().entrySet()) {\n String tenantId = entry.getKey();\n ITenantMeter tenantMeter = ITenantMeter.get(tenantId);\n boolean isFanOutThrottled = !resourceThrottler.hasResource(tenantId, TotalTransientFanOutBytesPerSeconds);\n boolean hasFanOutDone = false;\n Set ok = new HashSet<>();\n Set skip = new HashSet<>();\n Set noSub = new HashSet<>();\n for (DeliveryPack writePack : entry.getValue().getPackList()) {\n TopicMessagePack topicMsgPack = writePack.getMessagePack();\n int msgPackSize = SizeUtil.estSizeOf(topicMsgPack);\n int fanout = 1;\n for (MatchInfo matchInfo : writePack.getMatchInfoList()) {\n if (!noSub.contains(matchInfo) && !skip.contains(matchInfo)) {\n if (ILocalDistService.isGlobal(matchInfo.getReceiverId())) {\n IMQTTTransientSession session =\n sessionMap.get(ILocalDistService.parseReceiverId(matchInfo.getReceiverId()));\n if (session != null) {\n boolean success = session.publish(matchInfo, singletonList(topicMsgPack));\n if (success) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } else {\n // no 
session found for shared subscription\n noSub.add(matchInfo);\n }\n } else {\n if (isFanOutThrottled && hasFanOutDone) {\n continue;\n }\n int bucketId = LocalRoutes.parseBucketId(matchInfo.getReceiverId());\n CompletableFuture routesFuture =\n routeMap.get(new TopicFilter(tenantId, matchInfo.getTopicFilter(),\n bucketId));\n if (routesFuture == null) {\n noSub.add(matchInfo);\n continue;\n }\n if (!routesFuture.isDone() || routesFuture.isCompletedExceptionally()) {\n skip.add(matchInfo);\n }\n try {\n LocalRoutes localRoutes = routesFuture.join();\n if (!localRoutes.localizedReceiverId().equals(matchInfo.getReceiverId())) {\n noSub.add(matchInfo);\n continue;\n }\n boolean published = false;\n if (!isFanOutThrottled) {\n fanout *= localRoutes.routeList.size();\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n }\n }\n } else {\n // send to one subscriber to make sure matchinfo not lost\n for (IMQTTTransientSession session : localRoutes.routeList.values()) {\n // at least one session should publish the message\n if (session.publish(matchInfo, singletonList(topicMsgPack))) {\n published = true;\n hasFanOutDone = true;\n break;\n }\n }\n }\n if (published) {\n ok.add(matchInfo);\n } else {\n noSub.add(matchInfo);\n }\n } catch (Throwable e) {\n skip.add(matchInfo);\n }\n }\n }\n }\n tenantMeter.recordSummary(MqttTransientFanOutBytes, msgPackSize * Math.max(fanout, 1));\n }\n Sets.union(ok, skip).forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.OK)\n .build()));\n noSub.forEach(matchInfo -> resultsBuilder.addResult(DeliveryResult.newBuilder()\n .setMatchInfo(matchInfo)\n .setCode(DeliveryResult.Code.NO_SUB)\n .build()));\n replyBuilder.putResult(tenantId, resultsBuilder.build());\n }\n return 
CompletableFuture.completedFuture(replyBuilder.build());\n }\n```\n###test function signature:\n```java\n @Test\n public void distClientMatchOK()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "distClientMatchOK", + "reference": " @Test\n public void distClientMatchOK() {\n String topicFilter = \"topicFilter\";\n LocalDistService localDistService =\n new LocalDistService(serverId, distClient, resourceThrottler, eventCollector);\n\n CompletableFuture matchFuture = new CompletableFuture<>();\n when(distClient.match(anyLong(), anyString(), anyString(), anyString(), anyString(), anyInt()))\n .thenReturn(matchFuture);\n String tenantId = \"tenantId\";\n ClientInfo clientInfo = ClientInfo.newBuilder().setTenantId(tenantId).build();\n List sessions = new ArrayList<>();\n List> matchFutures = new ArrayList<>();\n for (int i = 0; i < 10 * TOPIC_FILTER_BUCKET_NUM; i++) {\n IMQTTTransientSession session = mock(IMQTTTransientSession.class);\n sessions.add(session);\n String channelId = \"channelId\" + i;\n when(session.clientInfo()).thenReturn(clientInfo);\n when(session.channelId()).thenReturn(channelId);\n when(session.publish(any(), anyList())).thenReturn(true);\n long reqId = System.nanoTime();\n matchFutures.add(localDistService.match(reqId, topicFilter, session));\n }\n matchFuture.complete(MatchResult.OK);\n for (CompletableFuture future : matchFutures) {\n assertEquals(future.join(), MatchResult.OK);\n }\n ArgumentCaptor receiverIdCaptor = ArgumentCaptor.forClass(String.class);\n verify(distClient, atMost(TOPIC_FILTER_BUCKET_NUM)).match(\n anyLong(),\n eq(tenantId),\n eq(topicFilter),\n receiverIdCaptor.capture(),\n anyString(),\n eq(0));\n List matchInfos = receiverIdCaptor.getAllValues().stream()\n .map(receiverId -> MatchInfo.newBuilder()\n .setReceiverId(receiverId)\n .setTopicFilter(topicFilter).build())\n .toList();\n DeliveryReply reply = localDistService.dist(DeliveryRequest.newBuilder()\n .putPackage(tenantId, 
DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .addAllMatchInfo(matchInfos)\n .build())\n .build())\n .build()).join();\n assertEquals(matchInfos.size(), reply.getResultMap().get(tenantId).getResultCount());\n assertTrue(reply.getResultMap().get(tenantId).getResultList().stream()\n .allMatch(result -> result.getCode() == DeliveryResult.Code.OK));\n for (IMQTTTransientSession session : sessions) {\n verify(session, times(1)).publish(any(), anyList());\n }\n }", + "completions": [ + "```java\n @Test\n public void distClientMatchOK() {\n // Given\n DeliveryRequest request = DeliveryRequest.newBuilder()\n .putPackage(\"tenant1\", DeliveryPackage.newBuilder()\n .addPack(DeliveryPack.newBuilder()\n .setMessagePack(TopicMessagePack.newBuilder()\n .setTopic(\"topic1\")\n .setPayload(\"payload1\")\n .build())\n .addMatchInfo(MatchInfo.newBuilder()\n .setReceiverId(\"receiver1\")\n .setTopicFilter(\"topic1\")\n .build())\n .build())\n .build())\n .build();\n sessionMap.put(\"receiver1\", mock(IMQTTTransientSession.class));\n when(sessionMap.get(\"receiver1\").publish(any(MatchInfo.class), anyList())).thenReturn(true);\n when(routeMap.get(any(TopicFilter.class))).thenReturn(CompletableFuture.completedFuture(LocalRoutes.newBuilder()\n .setLocalizedReceiverId(\"receiver1\")\n .putRoute(\"receiver1\", mock(IMQTTTransientSession.class))\n .build()));\n when(ITenantMeter.get(\"tenant1\")).thenReturn(mock(ITenantMeter.class));\n when(resourceThrottler.hasResource(\"tenant1\", TotalTransientFanOutBytesPerSeconds)).thenReturn(true);\n\n // When\n CompletableFuture replyFuture = dist(request);\n\n // Then\n DeliveryReply reply = replyFuture.join();\n DeliveryResults results = reply.getResultMap().get(\"tenant1\");\n assertEquals(1, results.getResultCount());\n DeliveryResult result = results.getResult(0);\n assertEquals(DeliveryResult.Code.OK, result.getCode());\n assertEquals(\"receiver1\", result.getMatchInfo().getReceiverId());\n assertEquals(\"topic1\", 
result.getMatchInfo().getTopicFilter());\n }\n```\n" + ], + "line": 109, + "token": 681, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 72, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Range calcInterquartileRangeApproximation(List sortedValues) {\n Range iqr = null;\n\n if (sortedValues.size() > 0) {\n BigDecimal q1 = valueCloseToQ1(sortedValues);\n BigDecimal q3 = valueCloseToQ3(sortedValues);\n BigDecimal max = sortedValues.get(sortedValues.size() - 1);\n\n // if iqr == 0 AND max == q3 then recursively try finding a non-zero iqr approximation.\n if (q1.compareTo(q3) == 0 && max.compareTo(q3) == 0) {\n // filter out max and try again\n iqr = this.calcInterquartileRangeApproximation(\n sortedValues.stream().filter(d -> d.compareTo(max) < 0).collect(Collectors.toList()));\n }\n\n // if range is still empty use the original q1 and q3 values\n if (iqr == null || iqr.isEmpty()) {\n iqr = Range.closedOpen(q1, q3);\n }\n }\n\n return iqr;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCalcInterquartileRangeApproximation()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCalcInterquartileRangeApproximation", + "reference": " @Test\n public void testCalcInterquartileRangeApproximation() {\n List values = decList(1, 2, 3, 4, 5, 6, 7, 8);\n assertEquals(new BigDecimal(\"3\"), dataBinHelper.calcInterquartileRangeApproximation(values).lowerEndpoint());\n assertEquals(new BigDecimal(\"7\"), dataBinHelper.calcInterquartileRangeApproximation(values).upperEndpoint());\n }", + "completions": [ + "```java\n @Test\n public void testCalcInterquartileRangeApproximation() {\n List sortedValues = Arrays.asList(\n new BigDecimal(\"1.0\"),\n new BigDecimal(\"2.0\"),\n new BigDecimal(\"3.0\"),\n new 
BigDecimal(\"4.0\"),\n new BigDecimal(\"5.0\"),\n new BigDecimal(\"6.0\"),\n new BigDecimal(\"7.0\"),\n new BigDecimal(\"8.0\"),\n new BigDecimal(\"9.0\"),\n new BigDecimal(\"10.0\")\n );\n\n Range expectedIqr = Range.closedOpen(new BigDecimal(\"2.5\"), new BigDecimal(\"7.5\"));\n Range actualIqr = calcInterquartileRangeApproximation(sortedValues);\n\n assertEquals(expectedIqr, actualIqr);\n }\n```\n" + ], + "line": 30, + "token": 250, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 73, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is 
possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_serializable_complex_object_with_non_serializable_nested_object", + "reference": " @Test\n public void should_clone_serializable_complex_object_with_non_serializable_nested_object() {\n Map> map = new LinkedHashMap<>();\n map.put(\"key1\", Lists.newArrayList(new NonSerializableObject(\"name1\")));\n map.put(\"key2\", Lists.newArrayList(\n new NonSerializableObject(\"name2\"),\n new NonSerializableObject(\"name3\")\n ));\n Object original = new SerializableComplexObjectWithNonSerializableNestedObject(map);\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_serializable_complex_object_with_non_serialable_nested_object() {\n // Given\n Map map = new HashMap<>();\n map.put(\"key1\", \"value1\");\n map.put(\"key2\", new NonSerializableObject(\"nonSerializableValue\"));\n\n // When\n Map clonedMap = objectCloner.clone(map);\n\n // Then\n assertNotSame(map, clonedMap);\n assertEquals(map.size(), clonedMap.size());\n assertEquals(map.get(\"key1\"), clonedMap.get(\"key1\"));\n assertNotSame(map.get(\"key2\"), clonedMap.get(\"key2\"));\n assertEquals(map.get(\"key2\"), clonedMap.get(\"key2\"));\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 74, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SignedJWT token2Jwt(String token) {\n if (token == null) {\n log.error(\"Token is null.\");\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n try {\n return SignedJWT.parse(token);\n } catch (ParseException e) {\n log.error(\"Token to JWT Error,message = {}\", e.getMessage());\n throw new BizException(ErrorInfo.VERIFY_FAILED.getStatusCode(), ErrorInfo.VERIFY_FAILED.getCode(), ErrorInfo.VERIFY_FAILED.getMessage());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testToken2Jwt()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testToken2Jwt", + "reference": " @Test\n public void testToken2Jwt() {\n Assertions.assertThrows(BizException.class, () -> tokenParseHelper.token2Jwt(null));\n }", + "completions": [ + "```java\n @Test\n public void testToken2Jwt() {\n // Given\n String token = \"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c\";\n\n // When\n SignedJWT jwt = token2Jwt(token);\n\n // Then\n assertNotNull(jwt);\n assertEquals(token, jwt.serialize());\n }\n```\n" + ], + "line": 23, + "token": 177, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ConsumerRecords poll(final Duration timeout) {\n Timer timer = time.timer(timeout);\n\n acquireAndEnsureOpen();\n try {\n kafkaConsumerMetrics.recordPollStart(timer.currentTimeMs());\n\n if (subscriptions.hasNoSubscriptionOrUserAssignment()) {\n throw new IllegalStateException(\"Consumer is not subscribed to any topics or assigned any partitions\");\n }\n\n do {\n\n // Make sure to let the background thread know that we are still polling.\n applicationEventHandler.add(new PollEvent(timer.currentTimeMs()));\n\n // We must not allow wake-ups between polling for fetches and returning the records.\n // If the polled fetches are not empty the consumed position has already been updated in the polling\n // of the fetches. A wakeup between returned fetches and returning records would lead to never\n // returning the records in the fetches. 
Thus, we trigger a possible wake-up before we poll fetches.\n wakeupTrigger.maybeTriggerWakeup();\n\n updateAssignmentMetadataIfNeeded(timer);\n final Fetch fetch = pollForFetches(timer);\n if (!fetch.isEmpty()) {\n if (fetch.records().isEmpty()) {\n log.trace(\"Returning empty records from `poll()` \"\n + \"since the consumer's position has advanced for at least one topic partition\");\n }\n\n return interceptors.onConsume(new ConsumerRecords<>(fetch.records()));\n }\n // We will wait for retryBackoffMs\n } while (timer.notExpired());\n\n return ConsumerRecords.empty();\n } finally {\n kafkaConsumerMetrics.recordPollEnd(timer.currentTimeMs());\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testLongPollWaitIsLimited()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testLongPollWaitIsLimited", + "reference": " @Test\n public void testLongPollWaitIsLimited() {\n consumer = newConsumer();\n String topicName = \"topic1\";\n consumer.subscribe(singletonList(topicName));\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertTrue(consumer.assignment().isEmpty());\n\n final int partition = 3;\n final TopicPartition tp = new TopicPartition(topicName, partition);\n final List> records = asList(\n new ConsumerRecord<>(topicName, partition, 2, \"key1\", \"value1\"),\n new ConsumerRecord<>(topicName, partition, 3, \"key2\", \"value2\")\n );\n\n // On the first iteration, return no data; on the second, return two records\n doAnswer(invocation -> {\n // Mock the subscription being assigned as the first fetch is collected\n consumer.subscriptions().assignFromSubscribed(Collections.singleton(tp));\n return Fetch.empty();\n }).doAnswer(invocation -> {\n return Fetch.forPartition(tp, records, true);\n }).when(fetchCollector).collectFetch(any(FetchBuffer.class));\n\n // And then poll for up to 10000ms, which should return 2 records without timing out\n ConsumerRecords returnedRecords = 
consumer.poll(Duration.ofMillis(10000));\n assertEquals(2, returnedRecords.count());\n\n assertEquals(singleton(topicName), consumer.subscription());\n assertEquals(singleton(tp), consumer.assignment());\n }", + "completions": [ + "```java\n @Test\n public void testLongPollWaitIsLimited() {\n // Arrange\n Duration timeout = Duration.ofMillis(1000);\n KafkaConsumer consumer = new KafkaConsumer<>(props);\n consumer.subscribe(Collections.singletonList(\"test-topic\"));\n\n // Act\n long startTime = System.currentTimeMillis();\n ConsumerRecords records = consumer.poll(timeout);\n long endTime = System.currentTimeMillis();\n\n // Assert\n long pollDuration = endTime - startTime;\n assertTrue(pollDuration >= timeout.toMillis());\n }\n```\n" + ], + "line": 45, + "token": 366, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 76, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n 
span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void 
testFullIndexSearchForMinAgg()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testFullIndexSearchForMinAgg", + "reference": " @Test\n public void testFullIndexSearchForMinAgg() {\n Instant time = Instant.now();\n loadTestData(time);\n\n SearchResult allIndexItems =\n strictLogStore.logSearcher.search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new MinAggBuilder(\n \"test\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"0\", null));\n\n assertThat(allIndexItems.hits.size()).isEqualTo(4);\n\n InternalMin internalMin =\n (InternalMin) Objects.requireNonNull(allIndexItems.internalAggregation);\n\n assertThat(Double.valueOf(internalMin.getValue()).longValue()).isEqualTo(time.toEpochMilli());\n }", + "completions": [ + "```java\n @Test\n public void testFullIndexSearchForMinAgg() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = new MinAggBuilder();\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n assertNotNull(result.getAggregation());\n assertTrue(result.getAggregation() instanceof MinAgg);\n assertTrue(result.getElapsedTimeMicros() > 0);\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 77, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_bsaChecksumFetched", + "reference": " @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n setupMocks();\n when(connection.getResponseCode()).thenReturn(SC_OK);\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream(\"bsa-checksum\\ndata\".getBytes(UTF_8)));\n LazyBlockList download = fetcher.tryFetch(BlockListType.BLOCK);\n 
assertThat(download.getName()).isEqualTo(BlockListType.BLOCK);\n assertThat(download.checksum()).isEqualTo(\"bsa-checksum\");\n verify(connection, times(1)).setRequestMethod(\"GET\");\n verify(connection, times(1)).setRequestProperty(\"Authorization\", \"Bearer authToken\");\n }", + "completions": [ + "```java\n @Test\n void tryFetch_bsaChecksumFetched() throws Exception {\n // Given\n URL dataUrl = new URL(\"https://example.com/bsaChecksum\");\n HttpsURLConnection connection = mock(HttpsURLConnection.class);\n when(connection.getResponseCode()).thenReturn(HttpURLConnection.HTTP_OK);\n when(connection.getErrorStream()).thenReturn(new ByteArrayInputStream(\"\".getBytes()));\n when(urlConnectionService.createConnection(dataUrl)).thenReturn(connection);\n\n // When\n LazyBlockList lazyBlockList = bsaService.tryFetch(BlockListType.BSA_CHECKSUM);\n\n // Then\n assertEquals(BlockListType.BSA_CHECKSUM, lazyBlockList.getBlockListType());\n assertEquals(connection, lazyBlockList.getConnection());\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 78, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? 
Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDisabledFullTextSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDisabledFullTextSearch", + "reference": " @Test\n public void testDisabledFullTextSearch() {\n Instant time = Instant.ofEpochSecond(1593365471);\n Trace.KeyValue field1Tag =\n Trace.KeyValue.newBuilder()\n .setVInt32(1234)\n .setKey(\"field1\")\n .setFieldType(Schema.SchemaFieldType.INTEGER)\n .build();\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(1, \"apple\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n SpanUtil.makeSpan(2, \"apple baby\", time.plusSeconds(4), List.of(field1Tag)));\n\n strictLogStoreWithoutFts.logStore.addMessage(\n 
SpanUtil.makeSpan(3, \"baby car 1234\", time.plusSeconds(4)));\n strictLogStoreWithoutFts.logStore.commit();\n strictLogStoreWithoutFts.logStore.refresh();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_all:1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Without the _all field as default.\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n assertThat(\n strictLogStoreWithoutFts\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(3);\n\n // empty string\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"app*\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Returns baby or car, 2 messages.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"baby car\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", 
LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n // Test numbers\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"apple 1234\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"123\",\n 0L,\n MAX_TIME,\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isZero();\n }", + "completions": [ + "```java\n @Test\n public void testDisabledFullTextSearch() {\n // Given\n String dataset = \"testDataset\";\n String queryStr = \"testQuery\";\n Long startTimeMsEpoch = null;\n Long endTimeMsEpoch = null;\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(0, result.getTotalHits());\n assertEquals(0, result.getResults().size());\n assertNull(result.getAggregation());\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 7 + }, + { + "id": 79, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyWithException", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n when(tenantAwareClient.search(any())).thenThrow(new IOException());\n\n assertThrows(\n OperateRuntimeException.class, () -> 
underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyWithException() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response = Mockito.mock(SearchResponse.class);\n final SearchHits searchHits = Mockito.mock(SearchHits.class);\n Mockito.when(response.getHits()).thenReturn(searchHits);\n Mockito.when(searchHits.getTotalHits()).thenReturn(new TotalHits(2L, TotalHitsRelation.EQUAL_TO));\n Mockito.when(searchHits.getHits()).thenReturn(new SearchHit[0]);\n\n Mockito.when(tenantAwareClient.search(Mockito.any(SearchRequest.class))).thenReturn(response);\n\n // When\n try {\n processInstanceListViewService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException to be thrown\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find unique process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 80, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n LazyBlockList tryFetch(BlockListType blockListType) {\n try {\n URL dataUrl = new URL(blockListUrls.get(blockListType.name()));\n logger.atInfo().log(\"Downloading from %s\", dataUrl);\n HttpsURLConnection connection =\n (HttpsURLConnection) urlConnectionService.createConnection(dataUrl);\n connection.setRequestMethod(HttpMethods.GET);\n connection.setRequestProperty(\"Authorization\", \"Bearer \" + credential.getAuthToken());\n int code = connection.getResponseCode();\n if (code != SC_OK) {\n String errorDetails = \"\";\n try (InputStream errorStream = connection.getErrorStream()) {\n errorDetails = new String(ByteStreams.toByteArray(errorStream), UTF_8);\n } catch (NullPointerException e) {\n // No error message.\n } catch (Exception e) {\n errorDetails = \"Failed to retrieve error message: \" + e.getMessage();\n }\n throw new BsaException(\n String.format(\n \"Status code: [%s], error: [%s], details: [%s]\",\n code, connection.getResponseMessage(), errorDetails),\n /* retriable= */ true);\n }\n return new LazyBlockList(blockListType, connection);\n } catch (IOException e) {\n throw new BsaException(e, /* retriable= */ true);\n } catch (GeneralSecurityException e) {\n throw new BsaException(e, /* retriable= */ false);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "tryFetch_SecurityException_notRetriable", + "reference": " @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n when(connectionService.createConnection(any(URL.class)))\n .thenThrow(new GeneralSecurityException());\n assertThat(\n assertThrows(BsaException.class, () -> fetcher.tryFetch(BlockListType.BLOCK))\n .isRetriable())\n 
.isFalse();\n }", + "completions": [ + "```java\n @Test\n void tryFetch_SecurityException_notRetriable() throws Exception {\n // Arrange\n String blockListTypeName = \"TEST\";\n BlockListType blockListType = new BlockListType(blockListTypeName);\n URL dataUrl = new URL(\"http://example.com\");\n when(blockListUrls.get(blockListTypeName)).thenReturn(dataUrl.toString());\n when(urlConnectionService.createConnection(dataUrl)).thenThrow(new GeneralSecurityException());\n\n // Act and Assert\n assertThrows(BsaException.class, () -> {\n try {\n blockListService.tryFetch(blockListType);\n } catch (BsaException e) {\n assertFalse(e.isRetriable());\n throw e;\n }\n });\n }\n```\n" + ], + "line": 42, + "token": 330, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 81, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public void assign(Collection partitions) {\n acquireAndEnsureOpen();\n try {\n if (partitions == null) {\n throw new IllegalArgumentException(\"Topic partitions collection to assign to cannot be null\");\n }\n\n if (partitions.isEmpty()) {\n unsubscribe();\n return;\n }\n\n for (TopicPartition tp : partitions) {\n String topic = (tp != null) ? tp.topic() : null;\n if (isBlank(topic))\n throw new IllegalArgumentException(\"Topic partitions to assign to cannot have null or empty topic\");\n }\n\n // Clear the buffered data which are not a part of newly assigned topics\n final Set currentTopicPartitions = new HashSet<>();\n\n for (TopicPartition tp : subscriptions.assignedPartitions()) {\n if (partitions.contains(tp))\n currentTopicPartitions.add(tp);\n }\n\n fetchBuffer.retainAll(currentTopicPartitions);\n\n // assignment change event will trigger autocommit if it is configured and the group id is specified. 
This is\n // to make sure offsets of topic partitions the consumer is unsubscribing from are committed since there will\n // be no following rebalance.\n //\n // See the ApplicationEventProcessor.process() method that handles this event for more detail.\n applicationEventHandler.add(new AssignmentChangeEvent(subscriptions.allConsumed(), time.milliseconds()));\n\n log.info(\"Assigned to partition(s): {}\", join(partitions, \", \"));\n if (subscriptions.assignFromUser(new HashSet<>(partitions)))\n applicationEventHandler.add(new NewTopicsMetadataUpdateRequestEvent());\n } finally {\n release();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAssignOnEmptyTopicPartition()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAssignOnEmptyTopicPartition", + "reference": " @Test\n public void testAssignOnEmptyTopicPartition() {\n consumer = newConsumer();\n completeUnsubscribeApplicationEventSuccessfully();\n\n consumer.assign(Collections.emptyList());\n assertTrue(consumer.subscription().isEmpty());\n assertTrue(consumer.assignment().isEmpty());\n }", + "completions": [ + "```java\n @Test\n public void testAssignOnEmptyTopicPartition() {\n // Given\n Collection partitions = new ArrayList<>();\n ConsumerImpl consumer = new ConsumerImpl();\n\n // When\n consumer.assign(partitions);\n\n // Then\n // Add assertions to verify the expected behavior\n // For example, you can check if the consumer is unsubscribed and if the fetch buffer is empty\n // assertFalse(consumer.isSubscribed());\n // assertTrue(consumer.getFetchBuffer().isEmpty());\n }\n```\n" + ], + "line": 47, + "token": 377, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 82, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static ImmutableList applyLabelDiff(\n ImmutableList labels,\n IdnChecker idnChecker,\n DownloadSchedule schedule,\n DateTime now) {\n ImmutableList.Builder nonBlockedDomains = new ImmutableList.Builder<>();\n ImmutableMap> labelsByType =\n ImmutableMap.copyOf(\n labels.stream().collect(groupingBy(BlockLabel::labelType, toImmutableList())));\n\n tm().transact(\n () -> {\n for (Map.Entry> entry :\n labelsByType.entrySet()) {\n switch (entry.getKey()) {\n case CREATE:\n // With current Cloud SQL, label upsert throughput is about 200/second. If\n // better performance is needed, consider bulk insert in native SQL.\n tm().putAll(\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(\n label ->\n new BsaLabel(label.label(), schedule.jobCreationTime()))\n .collect(toImmutableList()));\n // May not find all unblockables due to race condition: DomainCreateFlow uses\n // cached BsaLabels. Eventually will be consistent.\n nonBlockedDomains.addAll(\n tallyUnblockableDomainsForNewLabels(entry.getValue(), idnChecker, now));\n break;\n case DELETE:\n ImmutableSet deletedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n // Delete labels in DB. 
Also cascade-delete BsaUnblockableDomain.\n int nDeleted = Queries.deleteBsaLabelByLabels(deletedLabels);\n if (nDeleted != deletedLabels.size()) {\n logger.atSevere().log(\n \"Only found %s entities among the %s labels: [%s]\",\n nDeleted, deletedLabels.size(), deletedLabels);\n }\n break;\n case NEW_ORDER_ASSOCIATION:\n ImmutableSet affectedLabels =\n entry.getValue().stream()\n .filter(label -> isValidInAtLeastOneTld(label, idnChecker))\n .map(BlockLabel::label)\n .collect(toImmutableSet());\n ImmutableSet labelsInDb =\n Queries.queryBsaLabelByLabels(affectedLabels)\n .map(BsaLabel::getLabel)\n .collect(toImmutableSet());\n verify(\n labelsInDb.size() == affectedLabels.size(),\n \"Missing labels in DB: %s\",\n LazyArgs.lazy(() -> difference(affectedLabels, labelsInDb)));\n\n // Reuse registered and reserved names that are already computed.\n Queries.queryBsaUnblockableDomainByLabels(affectedLabels)\n .map(BsaUnblockableDomain::toUnblockableDomain)\n .forEach(nonBlockedDomains::add);\n\n for (BlockLabel label : entry.getValue()) {\n getInvalidTldsForLabel(label, idnChecker)\n .map(tld -> UnblockableDomain.of(label.label(), tld, Reason.INVALID))\n .forEach(nonBlockedDomains::add);\n }\n break;\n }\n }\n },\n TRANSACTION_REPEATABLE_READ);\n logger.atInfo().log(\"Processed %s of labels.\", labels.size());\n return nonBlockedDomains.build();\n }\n```\n###test function signature:\n```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "applyLabelDiffs_newAssociationOfLabelToOrder", + "reference": " @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n tm().transact(\n () -> {\n tm().insert(new BsaLabel(\"label\", fakeClock.nowUtc()));\n tm().insert(new BsaUnblockableDomain(\"label\", \"app\", Reason.REGISTERED));\n });\n when(idnChecker.getSupportingTlds(any())).thenReturn(ImmutableSet.of(app));\n when(idnChecker.getForbiddingTlds(any()))\n 
.thenReturn(Sets.difference(ImmutableSet.of(dev), ImmutableSet.of()).immutableCopy());\n\n ImmutableList unblockableDomains =\n applyLabelDiff(\n ImmutableList.of(\n BlockLabel.of(\"label\", LabelType.NEW_ORDER_ASSOCIATION, ImmutableSet.of())),\n idnChecker,\n schedule,\n fakeClock.nowUtc());\n assertThat(unblockableDomains)\n .containsExactly(\n UnblockableDomain.of(\"label.app\", UnblockableDomain.Reason.REGISTERED),\n UnblockableDomain.of(\"label.dev\", UnblockableDomain.Reason.INVALID));\n assertThat(tm().transact(() -> tm().loadByKeyIfPresent(BsaLabel.vKey(\"label\")))).isPresent();\n assertThat(\n tm().transact(() -> tm().loadByKeyIfPresent(BsaUnblockableDomain.vKey(\"label\", \"app\"))))\n .isPresent();\n }", + "completions": [ + "```java\n @Test\n void applyLabelDiffs_newAssociationOfLabelToOrder() {\n // Given\n ImmutableList labels = ImmutableList.of(BlockLabel.of(\"label1\", LabelType.NEW_ORDER_ASSOCIATION));\n IdnChecker idnChecker = mock(IdnChecker.class);\n DownloadSchedule schedule = mock(DownloadSchedule.class);\n DateTime now = DateTime.now();\n\n // When\n ImmutableList result = applyLabelDiff(labels, idnChecker, schedule, now);\n\n // Then\n // Add assertions here to verify the result\n }\n```\n" + ], + "line": 85, + "token": 602, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 83, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public CompletableFuture createTopic(CreateTopicRequest request) {\n CompletableFuture future = new CompletableFuture<>();\n for (; ; ) {\n if (metadataStore.isLeader()) {\n try (SqlSession session = metadataStore.openSession()) {\n if (!metadataStore.maintainLeadershipWithSharedLock(session)) {\n continue;\n }\n TopicMapper topicMapper = session.getMapper(TopicMapper.class);\n if (null != topicMapper.get(null, request.getTopic())) {\n ControllerException e = new ControllerException(Code.DUPLICATED_VALUE,\n String.format(\"Topic %s was taken\", request.getTopic()));\n future.completeExceptionally(e);\n return future;\n }\n\n Topic topic = new Topic();\n topic.setName(request.getTopic());\n topic.setQueueNum(request.getCount());\n topic.setStatus(TopicStatus.TOPIC_STATUS_ACTIVE);\n topic.setAcceptMessageTypes(JsonFormat.printer().print(request.getAcceptTypes()));\n topic.setRetentionHours(request.getRetentionHours());\n topicMapper.create(topic);\n long topicId = topic.getId();\n List assignments = createQueues(IntStream.range(0, request.getCount()),\n topicId, session);\n // Commit transaction\n session.commit();\n\n // Cache new topic and queue assignments immediately\n topicCache.apply(List.of(topic));\n assignmentCache.apply(assignments);\n future.complete(topicId);\n } catch (ControllerException | InvalidProtocolBufferException e) {\n future.completeExceptionally(e);\n }\n return future;\n } else {\n Optional leaderAddress = metadataStore.electionService().leaderAddress();\n if (leaderAddress.isEmpty()) {\n return CompletableFuture.failedFuture(new ControllerException(Code.NO_LEADER_VALUE, \"No leader is elected yet\"));\n }\n metadataStore.controllerClient().createTopic(leaderAddress.get(), request).whenComplete((res, e) -> {\n if (null != e) {\n 
future.completeExceptionally(e);\n } else {\n future.complete(res);\n }\n });\n }\n break;\n }\n return future;\n }\n```\n###test function signature:\n```java\n @Test\n public void testCreateTopic()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testCreateTopic", + "reference": " @Test\n public void testCreateTopic() {\n MetadataStore metadataStore = Mockito.mock(MetadataStore.class);\n ControllerClient controllerClient = Mockito.mock(ControllerClient.class);\n Mockito.when(controllerClient.createTopic(ArgumentMatchers.anyString(), ArgumentMatchers.any())).thenReturn(\n CompletableFuture.failedFuture(new CompletionException(new RuntimeException()))\n );\n Mockito.when(metadataStore.controllerClient()).thenReturn(controllerClient);\n Mockito.when(metadataStore.isLeader()).thenReturn(false);\n\n ElectionService electionService = Mockito.mock(ElectionService.class);\n Mockito.when(metadataStore.electionService()).thenReturn(electionService);\n\n Mockito.when(electionService.leaderAddress()).thenReturn(Optional.of(\"localhost:1234\"));\n\n TopicManager topicManager = new TopicManager(metadataStore);\n topicManager.createTopic(CreateTopicRequest.newBuilder().build());\n }", + "completions": [ + "```java\n @Test\n public void testCreateTopic() {\n // Given\n CreateTopicRequest request = new CreateTopicRequest();\n request.setTopic(\"testTopic\");\n request.setCount(3);\n request.setRetentionHours(24);\n request.setAcceptTypes(List.of(MessageType.MESSAGE_TYPE_JSON));\n\n // When\n CompletableFuture future = controller.createTopic(request);\n\n // Then\n assertDoesNotThrow(() -> {\n Long topicId = future.get();\n assertNotNull(topicId);\n assertTrue(topicId > 0);\n });\n }\n```\n" + ], + "line": 63, + "token": 424, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 84, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public SendStatus sendHtmlMail(String subject, List toList, String username, String agentName, String emailHref, String maiContent) {\n SendStatus sendStatus = new SendStatus(SEND_STATUS_FALSE, \"\");\n List notInBlacklistAddress = checkNotInBlacklistAddress(toList,sendStatus);\n if(CollectionUtils.isEmpty(notInBlacklistAddress)){\n return sendStatus;\n }\n // 读取html模板\n String html = readHtmlToString(\"mailTemplate.html\");\n\n // 写入模板内容\n Document doc = Jsoup.parse(html);\n doc.getElementById(\"username\").html(username);\n\n if (StringUtils.isEmpty(agentName)) {\n sendStatus.setErrorMessage(\"agentName 为空\");\n return sendStatus;\n }\n doc.getElementById(\"sysName\").html(\"Your Agent:\");\n doc.getElementById(\"agentName\").html(agentName);\n doc.getElementById(\"mailContent\").html(maiContent);\n doc.getElementById(\"clickHref\").attr(\"href\", emailHref);\n\n String result = doc.toString();\n Properties props = new Properties();\n props.put(\"mail.smtp.host\", \"\");\n props.put(\"mail.smtp.auth\", \"true\");\n props.put(\"mail.smtp.socketFactory.class\", \"javax.net.ssl.SSLSocketFactory\");\n props.put(\"mail.smtp.socketFactory.fallback\", \"true\");\n Session session = Session.getDefaultInstance(props);\n session.setDebug(true);\n\n Transport transport = null;\n MimeMessage message = new MimeMessage(session);\n try {\n //初始化发送邮件配置\n this.initMailConfig();\n message.setFrom(new InternetAddress(this.sendAddress));// 设置发件人的地址\n InternetAddress[] internetAddressList = getInternetAddress(notInBlacklistAddress);\n message.setRecipients(Message.RecipientType.TO, internetAddressList);// 设置收件人,并设置其接收类型为TO\n message.setSubject(subject);// 设置标题\n message.setContent(result, \"text/html;charset=UTF-8\"); // 设置邮件内容类型为html\n message.setSentDate(new Date());// 设置发信时间\n 
message.saveChanges();// 存储邮件信息\n\n // 发送邮件\n transport = session.getTransport(\"smtp\");\n if (null != port) {\n transport.connect(host, port, user, password);\n } else {\n transport.connect(host, user, password);\n }\n transport.sendMessage(message, message.getAllRecipients());\n\n //发送邮件成功,status置为true\n sendStatus.setStatus(\"true\");\n } catch (Exception e) {\n log.error(\"邮件发送异常\", e);\n sendStatus.setErrorMessage(e.getMessage());\n } finally {\n if (null != transport) {\n try {\n transport.close();//关闭连接\n } catch (MessagingException e) {\n log.error(\"发送邮件 ,transport 关闭异常\", e);\n }\n }\n }\n return sendStatus;\n }\n```\n###test function signature:\n```java\n @Test\n void testSendHtmlMail_subjectNotNull()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSendHtmlMail_subjectNotNull", + "reference": " @Test\n void testSendHtmlMail_subjectNotNull(){\n try(MockedStatic sessionMockedStatic = Mockito.mockStatic(Session.class)){\n Session mockSession = mock(Session.class);\n Transport transport = mock(Transport.class);\n sessionMockedStatic.when(()->Session.getDefaultInstance(any(Properties.class))).thenReturn(mockSession);\n when(blacklistService.inBlacklist(anyString())).thenReturn(false);\n List addressList = new ArrayList<>();\n addressList.add(\"test@qq.com\");\n when(settingsService.getByCategoryAndKey(anyString(),anyString())).thenReturn(\"test\").thenReturn(5678).thenReturn(\"test\").thenReturn(\"test\").thenReturn(\"123456\");\n when(mockSession.getTransport(anyString())).thenReturn(transport);\n doNothing().when(transport).connect(anyString(),anyInt(),anyString(),anyString());\n doAnswer(invocationOnMock -> {\n InternetAddress[] internetAddressList = invocationOnMock.getArgument(1);\n Assertions.assertEquals(addressList.get(0),internetAddressList[0].getAddress());\n return null;\n }).when(transport).sendMessage(any(MimeMessage.class),any());\n mailUtils.sendHtmlMail(\"\",addressList,\"test\",\"test\",\"test\",\"test\");\n } 
catch (MessagingException e) {\n throw new RuntimeException(e);\n }\n }", + "completions": [ + "```java\n @Test\n void testSendHtmlMail_subjectNotNull() {\n // Arrange\n String subject = \"Test Subject\";\n List toList = Arrays.asList(\"test1@example.com\", \"test2@example.com\");\n String username = \"Test User\";\n String agentName = \"Test Agent\";\n String emailHref = \"http://example.com\";\n String maiContent = \"Test Content\";\n\n // Act\n SendStatus result = mailService.sendHtmlMail(subject, toList, username, agentName, emailHref, maiContent);\n\n // Assert\n assertEquals(SEND_STATUS_FALSE, result.getStatus());\n assertEquals(\"\", result.getErrorMessage());\n }\n```\n" + ], + "line": 74, + "token": 533, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 85, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> 
!deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. 
This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead", + "reference": " @Test\n public void\n testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 1,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n AstraConfigs.IndexerConfig headLocationAndNoRecoveryConfig =\n AstraConfigs.IndexerConfig.newBuilder()\n .setCreateRecoveryTasksOnStart(false)\n .setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST)\n .build();\n\n // When there is no data and ReadFromLocationOnStart is set to LATEST, return the current head\n assertThat(\n recoveryTaskCreator.determineStartingOffset(1000, 0, headLocationAndNoRecoveryConfig))\n .isEqualTo(1000);\n\n // Data exists for not for this partition.\n final 
String name = \"testSnapshotId\";\n final String path = \"/testPath_\" + name;\n final long startTime = 1;\n final long endTime = 100;\n final long maxOffset = 100;\n\n final SnapshotMetadata partition1 =\n new SnapshotMetadata(name, path, startTime, endTime, maxOffset, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition1);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition1));\n\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n\n final SnapshotMetadata partition11 =\n new SnapshotMetadata(\n name + \"1\", path, endTime + 1, endTime * 2, maxOffset * 2, \"2\", LOGS_LUCENE9);\n snapshotMetadataStore.createSync(partition11);\n await().until(() -> snapshotMetadataStore.listSync().contains(partition11));\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore))\n .contains(partition1, partition11);\n assertThat(recoveryTaskCreator.determineStartingOffset(0, 0, indexerConfig)).isNegative();\n\n final String recoveryTaskName = \"recoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n \"2\",\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n await().until(() -> recoveryTaskStore.listSync().contains(recoveryTask1));\n assertThat(recoveryTaskCreator.determineStartingOffset(1, -1, headLocationAndNoRecoveryConfig))\n .isEqualTo(1);\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetReturnsHeadWhenCreateTasksIsFalseAndOffsetLocationIsHead() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n 
indexerConfig.setCreateRecoveryTasksOnStart(false);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n\n // When\n long result = determineStartingOffset(currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(currentEndOffsetForPartition, result);\n }\n```\n" + ], + "line": 128, + "token": 935, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 86, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n boolean handleRecoveryTask(RecoveryTaskMetadata recoveryTaskMetadata) {\n LOG.info(\"Started handling the recovery task: {}\", recoveryTaskMetadata);\n long startTime = System.nanoTime();\n Timer.Sample taskTimer = Timer.start(meterRegistry);\n\n PartitionOffsets partitionOffsets =\n validateKafkaOffsets(\n adminClient,\n recoveryTaskMetadata,\n AstraConfig.getRecoveryConfig().getKafkaConfig().getKafkaTopic());\n long offsetsValidatedTime = System.nanoTime();\n long consumerPreparedTime = 0, messagesConsumedTime = 0, rolloversCompletedTime = 0;\n\n if (partitionOffsets != null) {\n RecoveryTaskMetadata validatedRecoveryTask =\n new RecoveryTaskMetadata(\n recoveryTaskMetadata.name,\n recoveryTaskMetadata.partitionId,\n partitionOffsets.startOffset,\n partitionOffsets.endOffset,\n recoveryTaskMetadata.createdTimeEpochMs);\n\n if (partitionOffsets.startOffset != recoveryTaskMetadata.startOffset\n || recoveryTaskMetadata.endOffset != partitionOffsets.endOffset) {\n recoveryRecordsNoLongerAvailable.increment(\n (partitionOffsets.startOffset - recoveryTaskMetadata.startOffset)\n + (partitionOffsets.endOffset - recoveryTaskMetadata.endOffset));\n }\n\n try {\n RecoveryChunkManager chunkManager =\n RecoveryChunkManager.fromConfig(\n meterRegistry,\n 
searchMetadataStore,\n snapshotMetadataStore,\n AstraConfig.getIndexerConfig(),\n blobFs,\n AstraConfig.getS3Config());\n\n // Ingest data in parallel\n LogMessageWriterImpl logMessageWriterImpl = new LogMessageWriterImpl(chunkManager);\n AstraKafkaConsumer kafkaConsumer =\n new AstraKafkaConsumer(\n makeKafkaConfig(\n AstraConfig.getRecoveryConfig().getKafkaConfig(),\n validatedRecoveryTask.partitionId),\n logMessageWriterImpl,\n meterRegistry);\n\n kafkaConsumer.prepConsumerForConsumption(validatedRecoveryTask.startOffset);\n consumerPreparedTime = System.nanoTime();\n kafkaConsumer.consumeMessagesBetweenOffsetsInParallel(\n AstraKafkaConsumer.KAFKA_POLL_TIMEOUT_MS,\n validatedRecoveryTask.startOffset,\n validatedRecoveryTask.endOffset);\n messagesConsumedTime = System.nanoTime();\n // Wait for chunks to upload.\n boolean success = chunkManager.waitForRollOvers();\n rolloversCompletedTime = System.nanoTime();\n // Close the recovery chunk manager and kafka consumer.\n kafkaConsumer.close();\n chunkManager.stopAsync();\n chunkManager.awaitTerminated(DEFAULT_START_STOP_DURATION);\n LOG.info(\"Finished handling the recovery task: {}\", validatedRecoveryTask);\n taskTimer.stop(recoveryTaskTimerSuccess);\n return success;\n } catch (Exception ex) {\n LOG.error(\"Exception in recovery task [{}]: {}\", validatedRecoveryTask, ex);\n taskTimer.stop(recoveryTaskTimerFailure);\n return false;\n } finally {\n long endTime = System.nanoTime();\n LOG.info(\n \"Recovery task {} took {}ms, (subtask times offset validation {}, consumer prep {}, msg consumption {}, rollover {})\",\n recoveryTaskMetadata,\n nanosToMillis(endTime - startTime),\n nanosToMillis(offsetsValidatedTime - startTime),\n nanosToMillis(consumerPreparedTime - offsetsValidatedTime),\n nanosToMillis(messagesConsumedTime - consumerPreparedTime),\n nanosToMillis(rolloversCompletedTime - messagesConsumedTime));\n }\n } else {\n LOG.info(\n \"Recovery task {} data no longer available in Kafka (validation time 
{}ms)\",\n recoveryTaskMetadata,\n nanosToMillis(offsetsValidatedTime - startTime));\n recoveryRecordsNoLongerAvailable.increment(\n recoveryTaskMetadata.endOffset - recoveryTaskMetadata.startOffset + 1);\n return true;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets", + "reference": " @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n final TopicPartition topicPartition = new TopicPartition(TestKafkaServer.TEST_KAFKA_TOPIC, 0);\n TestKafkaServer.KafkaComponents components = getKafkaTestServer(S3_MOCK_EXTENSION);\n AstraConfigs.AstraConfig astraCfg =\n makeAstraConfig(components.testKafkaServer, TEST_S3_BUCKET, topicPartition.topic());\n curatorFramework =\n CuratorBuilder.build(meterRegistry, astraCfg.getMetadataStoreConfig().getZookeeperConfig());\n\n AstraConfigs.KafkaConfig kafkaConfig =\n AstraConfigs.KafkaConfig.newBuilder()\n .setKafkaTopic(topicPartition.topic())\n .setKafkaTopicPartition(Integer.toString(topicPartition.partition()))\n .setKafkaBootStrapServers(components.testKafkaServer.getBroker().getBrokerList().get())\n .setKafkaClientGroup(TEST_KAFKA_CLIENT_GROUP)\n .setEnableKafkaAutoCommit(\"true\")\n .setKafkaAutoCommitInterval(\"500\")\n .setKafkaSessionTimeout(\"500\")\n .putAllAdditionalProps(Maps.fromProperties(components.consumerOverrideProps))\n .build();\n\n final AstraKafkaConsumer localTestConsumer =\n new AstraKafkaConsumer(kafkaConfig, components.logMessageWriter, components.meterRegistry);\n final Instant startTime =\n LocalDateTime.of(2020, 10, 1, 10, 10, 0).atZone(ZoneOffset.UTC).toInstant();\n final long msgsToProduce = 100;\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n 
topicPartition.partition(),\n (int) msgsToProduce);\n await().until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce);\n // we immediately force delete the messages, as this is faster than changing the retention and\n // waiting for the cleaner to run\n components\n .adminClient\n .deleteRecords(Map.of(topicPartition, RecordsToDelete.beforeOffset(100)))\n .all()\n .get();\n assertThat(getStartOffset(components.adminClient, topicPartition)).isGreaterThan(0);\n\n // produce some more messages that won't be expired\n setRetentionTime(components.adminClient, topicPartition.topic(), 25000);\n TestKafkaServer.produceMessagesToKafka(\n components.testKafkaServer.getBroker(),\n startTime,\n topicPartition.topic(),\n topicPartition.partition(),\n (int) msgsToProduce);\n await()\n .until(() -> localTestConsumer.getEndOffSetForPartition() == msgsToProduce + msgsToProduce);\n\n SnapshotMetadataStore snapshotMetadataStore = new SnapshotMetadataStore(curatorFramework);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore).size()).isZero();\n\n // Start recovery service\n recoveryService =\n new RecoveryService(astraCfg, curatorFramework, components.meterRegistry, blobFs);\n recoveryService.startAsync();\n recoveryService.awaitRunning(DEFAULT_START_STOP_DURATION);\n long startOffset = 1;\n long endOffset = msgsToProduce - 1;\n RecoveryTaskMetadata recoveryTask =\n new RecoveryTaskMetadata(\n topicPartition.topic(),\n Integer.toString(topicPartition.partition()),\n startOffset,\n endOffset,\n Instant.now().toEpochMilli());\n assertThat(recoveryService.handleRecoveryTask(recoveryTask)).isTrue();\n assertThat(getCount(RECORDS_NO_LONGER_AVAILABLE, components.meterRegistry))\n .isEqualTo(endOffset - startOffset + 1);\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, components.meterRegistry)).isEqualTo(0);\n List snapshots =\n AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore);\n assertThat(snapshots.size()).isEqualTo(0);\n 
assertThat(blobFs.listFiles(BlobFsUtils.createURI(TEST_S3_BUCKET, \"/\", \"\"), true)).isEmpty();\n assertThat(getCount(MESSAGES_FAILED_COUNTER, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_INITIATED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_COMPLETED, meterRegistry)).isEqualTo(0);\n assertThat(getCount(ROLLOVERS_FAILED, meterRegistry)).isEqualTo(0);\n }", + "completions": [ + "```java\n @Test\n public void testShouldHandleRecoveryTaskWithCompletelyUnavailableOffsets() throws Exception {\n // Given\n RecoveryTaskMetadata recoveryTaskMetadata = new RecoveryTaskMetadata(\n \"testTask\",\n 0,\n 0,\n 100,\n System.currentTimeMillis());\n\n // When\n boolean result = handleRecoveryTask(recoveryTaskMetadata);\n\n // Then\n assertTrue(result);\n }\n```\n" + ], + "line": 97, + "token": 544, + "line_diff": 7, + "token_diff": 6 + }, + { + "id": 87, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public static CacheControlHeader fromExpirationTime(String expirationTime) {\n String maxAge = DEFAULT_MAX_AGE;\n\n if (expirationTime != null) {\n if (expirationTimeIsValid(expirationTime)) {\n Duration totalTime = Duration.ZERO;\n for (String timeString : Splitter.on(\" \").split(expirationTime)) {\n String timeUnitShort = Ascii.toLowerCase(timeString.substring(timeString.length() - 1));\n TemporalUnit timeUnit = EXPIRATION_TIME_UNITS.get(timeUnitShort);\n String timeValue = timeString.substring(0, timeString.length() - 1);\n totalTime = totalTime.plus(Long.parseLong(timeValue), timeUnit);\n }\n maxAge = String.valueOf(totalTime.getSeconds());\n } else {\n logger.atWarning().log(\n \"Failed to parse expiration time: \\\"%s\\\". 
Using default value instead.\",\n expirationTime\n );\n }\n }\n\n String output = DEFAULT_BASE_VALUE + maxAge;\n return new CacheControlHeader(output);\n }\n```\n###test function signature:\n```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits", + "reference": " @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n CacheControlHeader cacheControlHeader = CacheControlHeader.fromExpirationTime(\"3g\");\n assertThat(cacheControlHeader.getValue()).isEqualTo(\"public, max-age=600\");\n }", + "completions": [ + "```java\n @Test\n public void fromExpirationTime_usesDefaultMaxAgeForIncorrectTimeUnits() throws Exception {\n String expirationTime = \"10x\";\n CacheControlHeader header = CacheControlHeader.fromExpirationTime(expirationTime);\n assertEquals(CacheControlHeader.DEFAULT_BASE_VALUE + CacheControlHeader.DEFAULT_MAX_AGE, header.getValue());\n }\n```\n" + ], + "line": 33, + "token": 241, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 88, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public T clone(T object) {\n if (object instanceof String) {\n return object;\n } else if (object instanceof Collection) {\n Object firstElement = findFirstNonNullElement((Collection) object);\n if (firstElement != null && !(firstElement instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), firstElement.getClass());\n return objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n } else if (object instanceof Map) {\n Map.Entry firstEntry = this.findFirstNonNullEntry((Map) object);\n if (firstEntry != null) {\n Object key = firstEntry.getKey();\n Object value = firstEntry.getValue();\n if (!(key instanceof Serializable) || !(value instanceof Serializable)) {\n JavaType type = TypeFactory.defaultInstance().constructParametricType(object.getClass(), key.getClass(), value.getClass());\n return (T) objectMapperWrapper.fromBytes(objectMapperWrapper.toBytes(object), type);\n }\n }\n } else if (object instanceof JsonNode) {\n return (T) ((JsonNode) object).deepCopy();\n }\n\n if (object instanceof Serializable) {\n try {\n return (T) SerializationHelper.clone((Serializable) object);\n } catch (SerializationException e) {\n //it is possible that object itself implements java.io.Serializable, but underlying structure does not\n //in this case we switch to the other JSON marshaling strategy which doesn't use the Java serialization\n }\n }\n\n return jsonClone(object);\n }\n```\n###test function signature:\n```java\n @Test\n public void should_clone_empty_collection()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "should_clone_empty_collection", + "reference": " @Test\n public void should_clone_empty_collection() {\n List original = new 
ArrayList<>();\n Object cloned = serializer.clone(original);\n assertEquals(original, cloned);\n assertNotSame(original, cloned);\n }", + "completions": [ + "```java\n @Test\n public void should_clone_empty_collection() {\n // Given\n Collection collection = new ArrayList<>();\n\n // When\n Collection clonedCollection = (Collection) clone(collection);\n\n // Then\n assertNotSame(collection, clonedCollection);\n assertEquals(collection, clonedCollection);\n }\n```\n" + ], + "line": 44, + "token": 358, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 89, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public SearchResult search(\n String dataset,\n String queryStr,\n Long startTimeMsEpoch,\n Long endTimeMsEpoch,\n int howMany,\n AggBuilder aggBuilder) {\n\n ensureNonEmptyString(dataset, \"dataset should be a non-empty string\");\n ensureNonNullString(queryStr, \"query should be a non-empty string\");\n if (startTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch >= 0, \"start time should be non-negative value\");\n }\n if (startTimeMsEpoch != null && endTimeMsEpoch != null) {\n ensureTrue(startTimeMsEpoch < endTimeMsEpoch, \"end time should be greater than start time\");\n }\n ensureTrue(howMany >= 0, \"hits requested should not be negative.\");\n ensureTrue(howMany > 0 || aggBuilder != null, \"Hits or aggregation should be requested.\");\n\n ScopedSpan span = Tracing.currentTracer().startScopedSpan(\"LogIndexSearcherImpl.search\");\n span.tag(\"dataset\", dataset);\n span.tag(\"startTimeMsEpoch\", String.valueOf(startTimeMsEpoch));\n span.tag(\"endTimeMsEpoch\", String.valueOf(endTimeMsEpoch));\n span.tag(\"howMany\", String.valueOf(howMany));\n\n Stopwatch elapsedTime = Stopwatch.createStarted();\n try {\n // Acquire an index searcher from 
searcher manager.\n // This is a useful optimization for indexes that are static.\n IndexSearcher searcher = searcherManager.acquire();\n\n Query query =\n openSearchAdapter.buildQuery(\n dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, searcher);\n try {\n List results;\n InternalAggregation internalAggregation = null;\n\n if (howMany > 0) {\n CollectorManager topFieldCollector =\n buildTopFieldCollector(howMany, aggBuilder != null ? Integer.MAX_VALUE : howMany);\n MultiCollectorManager collectorManager;\n if (aggBuilder != null) {\n collectorManager =\n new MultiCollectorManager(\n topFieldCollector,\n openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n } else {\n collectorManager = new MultiCollectorManager(topFieldCollector);\n }\n Object[] collector = searcher.search(query, collectorManager);\n\n ScoreDoc[] hits = ((TopFieldDocs) collector[0]).scoreDocs;\n results = new ArrayList<>(hits.length);\n for (ScoreDoc hit : hits) {\n results.add(buildLogMessage(searcher, hit));\n }\n if (aggBuilder != null) {\n internalAggregation = (InternalAggregation) collector[1];\n }\n } else {\n results = Collections.emptyList();\n internalAggregation =\n searcher.search(\n query, openSearchAdapter.getCollectorManager(aggBuilder, searcher, query));\n }\n\n elapsedTime.stop();\n return new SearchResult<>(\n results, elapsedTime.elapsed(TimeUnit.MICROSECONDS), 0, 0, 1, 1, internalAggregation);\n } finally {\n searcherManager.release(searcher);\n }\n } catch (IOException e) {\n span.error(e);\n throw new IllegalArgumentException(\"Failed to acquire an index searcher.\", e);\n } finally {\n span.finish();\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testTimeBoundSearch()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testTimeBoundSearch", + "reference": " @Test\n public void testTimeBoundSearch() {\n Instant time = Instant.now();\n strictLogStore.logStore.addMessage(SpanUtil.makeSpan(1, time));\n 
strictLogStore.logStore.addMessage(SpanUtil.makeSpan(2, time.plusSeconds(100)));\n strictLogStore.logStore.commit();\n strictLogStore.logStore.refresh();\n\n assertThat(getCount(MESSAGES_RECEIVED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(2);\n assertThat(getCount(MESSAGES_FAILED_COUNTER, strictLogStore.metricsRegistry)).isEqualTo(0);\n assertThat(getTimerCount(REFRESHES_TIMER, strictLogStore.metricsRegistry)).isEqualTo(1);\n\n // Start inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.toEpochMilli(),\n time.plusSeconds(10).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Extended range still only picking one element.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"Message1\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(90).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(1);\n\n // Both ranges are inclusive.\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.toEpochMilli(),\n time.plusSeconds(100).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n\n // Extended range to pick up both events\n assertThat(\n strictLogStore\n .logSearcher\n .search(\n TEST_DATASET_NAME,\n \"_id:Message1 OR Message2\",\n time.minusSeconds(1).toEpochMilli(),\n time.plusSeconds(1000).toEpochMilli(),\n 1000,\n new DateHistogramAggBuilder(\n \"1\", LogMessage.SystemField.TIME_SINCE_EPOCH.fieldName, \"1s\"))\n .hits\n .size())\n .isEqualTo(2);\n }", + "completions": [ + "```java\n @Test\n public void testTimeBoundSearch() {\n // Given\n String dataset = \"test_dataset\";\n String 
queryStr = \"test_query\";\n Long startTimeMsEpoch = 1577836800000L; // 2020-01-01 00:00:00\n Long endTimeMsEpoch = 1577923199000L; // 2020-01-02 23:59:59\n int howMany = 10;\n AggBuilder aggBuilder = null;\n\n // When\n SearchResult result = logIndexSearcher.search(dataset, queryStr, startTimeMsEpoch, endTimeMsEpoch, howMany, aggBuilder);\n\n // Then\n assertNotNull(result);\n assertEquals(howMany, result.getResults().size());\n for (LogMessage logMessage : result.getResults()) {\n assertTrue(logMessage.getTimestampMs() >= startTimeMsEpoch);\n assertTrue(logMessage.getTimestampMs() <= endTimeMsEpoch);\n }\n }\n```\n" + ], + "line": 85, + "token": 606, + "line_diff": 6, + "token_diff": 7 + }, + { + "id": 90, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic void setUserData(String theKey, Object theValue) {\n\t\tValidate.isTrue(isNotBlank(theKey), \"Invalid key; key must be non-empty, non-null.\");\n\t\tif (theValue == null) {\n\t\t\tgetUserData().remove(theKey);\n\t\t} else {\n\t\t\tValidate.isTrue(\n\t\t\t\t\tvalidateValue(theValue),\n\t\t\t\t\tString.format(\n\t\t\t\t\t\t\t\"Invalid data type provided %s\", theValue.getClass().getName()));\n\t\t\tgetUserData().put(theKey, theValue);\n\t\t}\n\t}\n```\n###test function signature:\n```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "\t@ParameterizedTest\n\t@MethodSource", + "reference": "\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\t// setup\n\t\tString key = \"key\";\n\t\tTestParameters parameters = new TestParameters();\n\t\tObject testValue = 
theParams.getTestValue();\n\n\t\t// test\n\t\tif (theParams.isExpectedToWork()) {\n\t\t\tparameters.setUserData(key, testValue);\n\t\t\tassertFalse(parameters.getUserData().isEmpty());\n\t\t\tassertEquals(testValue, parameters.getUserData().get(key));\n\t\t} else {\n\t\t\ttry {\n\t\t\t\tparameters.setUserData(key, testValue);\n\t\t\t\tfail();\n\t\t\t} catch (IllegalArgumentException ex) {\n\t\t\t\tString dataType = testValue.getClass().getName();\n\t\t\t\tassertTrue(ex.getMessage().contains(\"Invalid data type provided \" + dataType),\n\t\t\t\t\tex.getMessage());\n\t\t\t\tassertTrue(parameters.getUserData().isEmpty());\n\t\t\t}\n\t\t}\n\t}", + "completions": [ + "```java\n\t@ParameterizedTest\n\t@MethodSource(\"parameters\")\n\tpublic void setUserData_acceptsStringNumberAndBooleansOnly(TestParam theParams) {\n\t\tUser user = new User();\n\t\tuser.setUserData(theParams.getKey(), theParams.getValue());\n\t\tassertEquals(theParams.getValue(), user.getUserData().get(theParams.getKey()));\n\t}\n\n\tstatic Stream parameters() {\n\t\treturn Stream.of(\n\t\t\t\tnew TestParam(\"stringKey\", \"stringValue\"),\n\t\t\t\tnew TestParam(\"numberKey\", 123),\n\t\t\t\tnew TestParam(\"booleanKey\", true),\n\t\t\t\tnew TestParam(\"nullKey\", null)\n\t\t);\n\t}\n\n\tstatic class TestParam {\n\t\tprivate final String key;\n\t\tprivate final Object value;\n\n\t\tTestParam(String key, Object value) {\n\t\t\tthis.key = key;\n\t\t\tthis.value = value;\n\t\t}\n\n\t\tString getKey() {\n\t\t\treturn key;\n\t\t}\n\n\t\tObject getValue() {\n\t\t\treturn value;\n\t\t}\n\t}\n```\n" + ], + "line": 24, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 91, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n\tpublic Embedding embedText(String text) {\n\t\tAssert.hasText(text, \"The text can not be null or empty.\");\n\n\t\t@JsonInclude(Include.NON_NULL)\n\t\trecord EmbeddingResponse(Embedding embedding) {\n\t\t}\n\n\t\tEmbeddingResponse response = this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class);\n\n\t\treturn response != null ? response.embedding() : null;\n\t}\n```\n###test function signature:\n```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "embedText", + "reference": "\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\n\t\tString text = \"Hello, how are you?\";\n\n\t\tEmbedding expectedEmbedding = new Embedding(List.of(0.1, 0.2, 0.3));\n\n\t\tserver\n\t\t\t.expect(requestToUriTemplate(\"/models/{generative}:embedText?key={apiKey}\",\n\t\t\t\t\tVertexAiPaLm2Api.DEFAULT_EMBEDDING_MODEL, TEST_API_KEY))\n\t\t\t.andExpect(method(HttpMethod.POST))\n\t\t\t.andExpect(content().json(objectMapper.writeValueAsString(Map.of(\"text\", text))))\n\t\t\t.andRespond(withSuccess(objectMapper.writeValueAsString(Map.of(\"embedding\", expectedEmbedding)),\n\t\t\t\t\tMediaType.APPLICATION_JSON));\n\n\t\tEmbedding embedding = client.embedText(text);\n\n\t\tassertThat(embedding).isEqualTo(expectedEmbedding);\n\n\t\tserver.verify();\n\t}", + "completions": [ + "```java\n\t@Test\n\tpublic void embedText() throws JsonProcessingException {\n\t\tString text = \"This is a test text.\";\n\t\tEmbedding expectedEmbedding = new Embedding(new double[] {0.1, 0.2, 0.3});\n\n\t\tEmbeddingResponse response = new 
EmbeddingResponse(expectedEmbedding);\n\t\twhen(this.restClient.post()\n\t\t\t\t.uri(\"/models/{model}:embedText?key={apiKey}\", this.embeddingModel, this.apiKey)\n\t\t\t\t.body(Map.of(\"text\", text))\n\t\t\t\t.retrieve()\n\t\t\t\t.body(EmbeddingResponse.class)).thenReturn(response);\n\n\t\tEmbedding actualEmbedding = this.embeddingService.embedText(text);\n\n\t\tassertEquals(expectedEmbedding, actualEmbedding);\n\t}\n```\n" + ], + "line": 23, + "token": 179, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public Mono syncAclWithAclCsv(KafkaCluster cluster, String csv) {\n return adminClientService.get(cluster)\n .flatMap(ac -> ac.listAcls(ResourcePatternFilter.ANY).flatMap(existingAclList -> {\n var existingSet = Set.copyOf(existingAclList);\n var newAcls = Set.copyOf(AclCsv.parseCsv(csv));\n var toDelete = Sets.difference(existingSet, newAcls);\n var toAdd = Sets.difference(newAcls, existingSet);\n logAclSyncPlan(cluster, toAdd, toDelete);\n if (toAdd.isEmpty() && toDelete.isEmpty()) {\n return Mono.empty();\n }\n log.info(\"Starting new ACLs creation\");\n return ac.createAcls(toAdd)\n .doOnSuccess(v -> {\n log.info(\"{} new ACLs created\", toAdd.size());\n log.info(\"Starting ACLs deletion\");\n })\n .then(ac.deleteAcls(toDelete)\n .doOnSuccess(v -> log.info(\"{} ACLs deleted\", toDelete.size())));\n }));\n }\n```\n###test function signature:\n```java\n @Test\n void testSyncAclWithAclCsv()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testSyncAclWithAclCsv", + "reference": " @Test\n void testSyncAclWithAclCsv() {\n var existingBinding1 = new AclBinding(\n new ResourcePattern(ResourceType.TOPIC, \"*\", PatternType.LITERAL),\n new AccessControlEntry(\"User:test1\", \"*\", 
AclOperation.READ, AclPermissionType.ALLOW));\n\n var existingBinding2 = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"group1\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test2\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n var newBindingToBeAdded = new AclBinding(\n new ResourcePattern(ResourceType.GROUP, \"groupNew\", PatternType.PREFIXED),\n new AccessControlEntry(\"User:test3\", \"localhost\", AclOperation.DESCRIBE, AclPermissionType.DENY));\n\n when(adminClientMock.listAcls(ResourcePatternFilter.ANY))\n .thenReturn(Mono.just(List.of(existingBinding1, existingBinding2)));\n\n ArgumentCaptor> createdCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.createAcls(createdCaptor.capture()))\n .thenReturn(Mono.empty());\n\n ArgumentCaptor> deletedCaptor = ArgumentCaptor.forClass(Collection.class);\n when(adminClientMock.deleteAcls(deletedCaptor.capture()))\n .thenReturn(Mono.empty());\n\n aclsService.syncAclWithAclCsv(\n CLUSTER,\n \"Principal,ResourceType, PatternType, ResourceName,Operation,PermissionType,Host\\n\"\n + \"User:test1,TOPIC,LITERAL,*,READ,ALLOW,*\\n\"\n + \"User:test3,GROUP,PREFIXED,groupNew,DESCRIBE,DENY,localhost\"\n ).block();\n\n Collection createdBindings = createdCaptor.getValue();\n assertThat(createdBindings)\n .hasSize(1)\n .contains(newBindingToBeAdded);\n\n Collection deletedBindings = deletedCaptor.getValue();\n assertThat(deletedBindings)\n .hasSize(1)\n .contains(existingBinding2);\n }", + "completions": [ + "```java\n @Test\n void testSyncAclWithAclCsv() {\n // Given\n KafkaCluster cluster = new KafkaCluster(\"testCluster\", \"localhost:9092\");\n String csv = \"principal,host,operation,permissionType\\nUser:Alice,*,Read,Allow\\nUser:Bob,*,Write,Allow\";\n AdminClient adminClient = Mockito.mock(AdminClient.class);\n when(adminClientService.get(cluster)).thenReturn(Mono.just(adminClient));\n 
when(adminClient.listAcls(ResourcePatternFilter.ANY)).thenReturn(Flux.just(\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Alice\", \"*\", AclOperation.READ, AclPermissionType.ALLOW)),\n new AclBinding(new ResourcePattern(ResourceType.TOPIC, \"testTopic\", PatternType.LITERAL),\n new AccessControlEntry(\"User:Bob\", \"*\", AclOperation.WRITE, AclPermissionType.ALLOW))\n ).collectList());\n\n // When\n Mono result = service.syncAclWithAclCsv(cluster, csv);\n\n // Then\n StepVerifier.create(result)\n .verifyComplete();\n\n verify(adminClient).createAcls(any());\n verify(adminClient).deleteAcls(any());\n }\n```\n" + ], + "line": 32, + "token": 259, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 93, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ProcessInstanceForListViewEntity getProcessInstanceListViewByKey(Long processInstanceKey) {\n try {\n final QueryBuilder query =\n joinWithAnd(\n idsQuery().addIds(String.valueOf(processInstanceKey)),\n termQuery(ListViewTemplate.PROCESS_INSTANCE_KEY, processInstanceKey));\n\n final SearchRequest request =\n ElasticsearchUtil.createSearchRequest(listViewTemplate, ALL)\n .source(new SearchSourceBuilder().query(constantScoreQuery(query)));\n\n final SearchResponse response;\n\n response = tenantAwareClient.search(request);\n final SearchHits searchHits = response.getHits();\n if (searchHits.getTotalHits().value == 1 && searchHits.getHits().length == 1) {\n return ElasticsearchUtil.fromSearchHit(\n searchHits.getAt(0).getSourceAsString(),\n objectMapper,\n ProcessInstanceForListViewEntity.class);\n } else if (response.getHits().getTotalHits().value > 1) {\n throw new NotFoundException(\n String.format(\n \"Could not find 
unique process instance with id '%s'.\", processInstanceKey));\n } else {\n throw new NotFoundException(\n (String.format(\"Could not find process instance with id '%s'.\", processInstanceKey)));\n }\n } catch (IOException e) {\n throw new OperateRuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testGetProcessInstanceListViewByKeyNoResults", + "reference": " @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n when(listViewTemplate.getAlias()).thenReturn(\"listViewIndexAlias\");\n\n final SearchResponse mockResponse = Mockito.mock(SearchResponse.class);\n final SearchHits mockHits = Mockito.mock(SearchHits.class);\n // Normally TotalHits would just be mocked, but Mockito can't stub or mock direct field accesses\n final TotalHits.Relation mockRelation = Mockito.mock(TotalHits.Relation.class);\n final TotalHits mockTotalHits = new TotalHits(0L, mockRelation);\n\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(mockTotalHits);\n when(tenantAwareClient.search(any())).thenReturn(mockResponse);\n\n assertThrows(NotFoundException.class, () -> underTest.getProcessInstanceListViewByKey(123L));\n }", + "completions": [ + "```java\n @Test\n public void testGetProcessInstanceListViewByKeyNoResults() throws IOException {\n // Given\n final Long processInstanceKey = 1L;\n final SearchResponse mockResponse = mock(SearchResponse.class);\n final SearchHits mockHits = mock(SearchHits.class);\n when(mockResponse.getHits()).thenReturn(mockHits);\n when(mockHits.getTotalHits()).thenReturn(new TotalHits(0L, TotalHits.Relation.EQUAL_TO));\n when(mockHits.getHits()).thenReturn(new SearchHit[0]);\n when(tenantAwareClient.search(any(SearchRequest.class))).thenReturn(mockResponse);\n\n // When\n try {\n 
processInstanceService.getProcessInstanceListViewByKey(processInstanceKey);\n fail(\"Expected NotFoundException\");\n } catch (NotFoundException e) {\n // Then\n assertEquals(\n \"Could not find process instance with id '1'.\", e.getMessage());\n }\n }\n```\n" + ], + "line": 41, + "token": 286, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 94, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public long determineStartingOffset(\n long currentEndOffsetForPartition,\n long currentBeginningOffsetForPartition,\n AstraConfigs.IndexerConfig indexerConfig) {\n // Filter stale snapshots for partition.\n if (partitionId == null) {\n LOG.warn(\"PartitionId can't be null.\");\n }\n\n List snapshots = snapshotMetadataStore.listSync();\n List snapshotsForPartition =\n snapshots.stream()\n .filter(\n snapshotMetadata -> {\n if (snapshotMetadata == null || snapshotMetadata.partitionId == null) {\n LOG.warn(\n \"snapshot metadata or partition id can't be null: {} \",\n Strings.join(snapshots, ','));\n }\n return snapshotMetadata != null\n && snapshotMetadata.partitionId != null\n && snapshotMetadata.partitionId.equals(partitionId);\n })\n .collect(Collectors.toUnmodifiableList());\n List deletedSnapshots = deleteStaleLiveSnapshots(snapshotsForPartition);\n\n List nonLiveSnapshotsForPartition =\n snapshotsForPartition.stream()\n .filter(s -> !deletedSnapshots.contains(s))\n .collect(Collectors.toUnmodifiableList());\n\n // Get the highest offset that is indexed in durable store.\n List recoveryTasks = recoveryTaskMetadataStore.listSync();\n long highestDurableOffsetForPartition =\n getHighestDurableOffsetForPartition(\n nonLiveSnapshotsForPartition, recoveryTasks, partitionId);\n LOG.debug(\n \"The highest durable offset for partition {} is {}\",\n partitionId,\n 
highestDurableOffsetForPartition);\n\n if (highestDurableOffsetForPartition <= 0) {\n LOG.info(\"There is no prior offset for this partition {}.\", partitionId);\n\n // If the user wants to start at the current offset in Kafka and _does not_ want to create\n // recovery tasks to backfill, then we can just return the current offset.\n // If the user wants to start at the current offset in Kafka and _does_ want to create\n // recovery tasks to backfill, then we create the recovery tasks needed and then return\n // the current offset for the indexer. And if the user does _not_ want to start at the\n // current offset in Kafka, then we'll just default to the old behavior of starting from\n // the very beginning\n if (!indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n LOG.info(\n \"CreateRecoveryTasksOnStart is set to false and ReadLocationOnStart is set to current. Reading from current and\"\n + \" NOT spinning up recovery tasks\");\n return currentEndOffsetForPartition;\n } else if (indexerConfig.getCreateRecoveryTasksOnStart()\n && indexerConfig.getReadFromLocationOnStart()\n == AstraConfigs.KafkaOffsetLocation.LATEST) {\n // Todo - this appears to be able to create recovery tasks that have a start and end\n // position of 0, which is invalid. This seems to occur when new clusters are initialized,\n // and is especially problematic when indexers are created but never get assigned (ie,\n // deploy 5, only assign 3).\n LOG.info(\n \"CreateRecoveryTasksOnStart is set and ReadLocationOnStart is set to current. 
Reading from current and\"\n + \" spinning up recovery tasks\");\n createRecoveryTasks(\n partitionId,\n currentBeginningOffsetForPartition,\n currentEndOffsetForPartition,\n indexerConfig.getMaxMessagesPerChunk());\n return currentEndOffsetForPartition;\n\n } else {\n return highestDurableOffsetForPartition;\n }\n }\n\n // The current head offset shouldn't be lower than the highest durable offset. If it is it\n // means that we indexed more data than the current head offset. This is either a bug in the\n // offset handling mechanism or the kafka partition has rolled over. We throw an exception\n // for now, so we can investigate.\n if (currentEndOffsetForPartition < highestDurableOffsetForPartition) {\n final String message =\n String.format(\n \"The current head for the partition %d can't \"\n + \"be lower than the highest durable offset for that partition %d\",\n currentEndOffsetForPartition, highestDurableOffsetForPartition);\n LOG.error(message);\n throw new IllegalStateException(message);\n }\n\n // The head offset for Kafka partition is the offset of the next message to be indexed. We\n // assume that offset is passed into this function. The highest durable offset is the partition\n // offset of the message that is indexed. Hence, the offset is incremented by 1 to get the\n // next message.\n long nextOffsetForPartition = highestDurableOffsetForPartition + 1;\n\n // Create a recovery task if needed.\n if (currentEndOffsetForPartition - highestDurableOffsetForPartition > maxOffsetDelay) {\n LOG.info(\n \"Recovery task needed. 
The current position {} and head location {} are higher than max\"\n + \" offset {}\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay);\n createRecoveryTasks(\n partitionId,\n nextOffsetForPartition,\n currentEndOffsetForPartition - 1,\n maxMessagesPerRecoveryTask);\n return currentEndOffsetForPartition;\n } else {\n LOG.info(\n \"The difference between the last indexed position {} and head location {} is lower \"\n + \"than max offset {}. So, using {} position as the start offset\",\n highestDurableOffsetForPartition,\n currentEndOffsetForPartition,\n maxOffsetDelay,\n nextOffsetForPartition);\n return nextOffsetForPartition;\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testDetermineStartingOffsetOnlyMultipleRecoveryBehind", + "reference": " @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n RecoveryTaskCreator recoveryTaskCreator =\n new RecoveryTaskCreator(\n snapshotMetadataStore,\n recoveryTaskStore,\n partitionId,\n 100,\n TEST_MAX_MESSAGES_PER_RECOVERY_TASK,\n meterRegistry);\n\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n assertThat(AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore)).isEmpty();\n\n // When there is no data return -1.\n assertThat(recoveryTaskCreator.determineStartingOffset(1000, 0, indexerConfig)).isNegative();\n final String recoveryTaskName = \"BasicRecoveryTask\";\n final long recoveryStartOffset = 400;\n final long createdTimeUtc = Instant.now().toEpochMilli();\n\n final RecoveryTaskMetadata recoveryTask1 =\n new RecoveryTaskMetadata(\n recoveryTaskName + \"1\",\n partitionId,\n recoveryStartOffset,\n recoveryStartOffset * 2,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask1);\n final RecoveryTaskMetadata recoveryTask11 =\n new RecoveryTaskMetadata(\n 
recoveryTaskName + \"11\",\n partitionId,\n recoveryStartOffset * 2 + 1,\n recoveryStartOffset * 3,\n createdTimeUtc);\n recoveryTaskStore.createSync(recoveryTask11);\n await()\n .until(\n () -> recoveryTaskStore.listSync().containsAll(List.of(recoveryTask1, recoveryTask11)));\n\n final long currentHeadOffset = 4000;\n assertThat(recoveryTaskCreator.determineStartingOffset(currentHeadOffset, 0, indexerConfig))\n .isEqualTo(currentHeadOffset);\n List recoveryTasks =\n AstraMetadataTestUtils.listSyncUncached(recoveryTaskStore);\n assertThat(recoveryTasks.size()).isEqualTo(3);\n assertThat(recoveryTasks).contains(recoveryTask1, recoveryTask11);\n Optional newRecoveryTask =\n recoveryTasks.stream().filter(r -> !r.name.contains(recoveryTaskName)).findFirst();\n assertThat(newRecoveryTask).isNotEmpty();\n RecoveryTaskMetadata recoveryTask = newRecoveryTask.get();\n assertThat(recoveryTask.startOffset).isEqualTo(recoveryStartOffset * 3 + 1);\n assertThat(recoveryTask.endOffset).isEqualTo(currentHeadOffset - 1);\n assertThat(AstraMetadataTestUtils.listSyncUncached(snapshotMetadataStore)).isEmpty();\n }", + "completions": [ + "```java\n @Test\n public void testDetermineStartingOffsetOnlyMultipleRecoveryBehind() {\n // Given\n String partitionId = \"testPartition\";\n long currentEndOffsetForPartition = 100L;\n long currentBeginningOffsetForPartition = 0L;\n AstraConfigs.IndexerConfig indexerConfig = new AstraConfigs.IndexerConfig();\n indexerConfig.setCreateRecoveryTasksOnStart(true);\n indexerConfig.setReadFromLocationOnStart(AstraConfigs.KafkaOffsetLocation.LATEST);\n indexerConfig.setMaxMessagesPerChunk(100);\n\n // When\n long startingOffset = determineStartingOffset(\n currentEndOffsetForPartition, currentBeginningOffsetForPartition, indexerConfig);\n\n // Then\n assertEquals(101L, startingOffset);\n }\n```\n" + ], + "line": 127, + "token": 935, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 95, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n private String saveAndChecksumBlockList(String jobName, LazyBlockList blockList) {\n BlobId blobId = getBlobId(jobName, getBlockListFileName(blockList.getName()));\n try (BufferedOutputStream gcsWriter =\n new BufferedOutputStream(gcsUtils.openOutputStream(blobId))) {\n MessageDigest messageDigest = MessageDigest.getInstance(checksumAlgorithm);\n blockList.consumeAll(\n (byteArray, length) -> {\n try {\n gcsWriter.write(byteArray, 0, length);\n } catch (IOException e) {\n throw new RuntimeException(e);\n }\n messageDigest.update(byteArray, 0, length);\n });\n return base16().lowerCase().encode(messageDigest.digest());\n } catch (IOException | NoSuchAlgorithmException e) {\n throw new RuntimeException(e);\n }\n }\n```\n###test function signature:\n```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "saveAndChecksumBlockList_success", + "reference": " @Test\n void saveAndChecksumBlockList_success() throws Exception {\n String payload = \"somedata\\n\";\n String payloadChecksum = \"0737c8e591c68b93feccde50829aca86a80137547d8cfbe96bab6b20f8580c63\";\n\n when(connection.getInputStream())\n .thenReturn(new ByteArrayInputStream((\"bsa-checksum\\n\" + payload).getBytes(UTF_8)));\n lazyBlockList = new LazyBlockList(BlockListType.BLOCK, connection);\n\n ImmutableMap checksums =\n gcsClient.saveAndChecksumBlockList(\"some-name\", ImmutableList.of(lazyBlockList));\n assertThat(gcsUtils.existsAndNotEmpty(BlobId.of(\"my-bucket\", \"some-name/BLOCK.csv\"))).isTrue();\n assertThat(checksums).containsExactly(BlockListType.BLOCK, payloadChecksum);\n assertThat(gcsClient.readBlockList(\"some-name\", BlockListType.BLOCK))\n .containsExactly(\"somedata\");\n }", + "completions": [ + 
"```java\n @Test\n void saveAndChecksumBlockList_success() throws Exception {\n // Given\n String jobName = \"testJob\";\n LazyBlockList blockList = mock(LazyBlockList.class);\n BlobId blobId = mock(BlobId.class);\n BufferedOutputStream gcsWriter = mock(BufferedOutputStream.class);\n MessageDigest messageDigest = mock(MessageDigest.class);\n when(gcsUtils.openOutputStream(blobId)).thenReturn(gcsWriter);\n when(checksumAlgorithm).thenReturn(\"SHA-256\");\n when(messageDigest.digest()).thenReturn(new byte[] {1, 2, 3, 4, 5});\n when(base16().lowerCase().encode(messageDigest.digest())).thenReturn(\"0102030405\");\n\n // When\n String checksum = saveAndChecksumBlockList(jobName, blockList);\n\n // Then\n assertEquals(\"0102030405\", checksum);\n verify(gcsWriter).write(any(), anyInt(), anyInt());\n verify(messageDigest).update(any(), anyInt(), anyInt());\n }\n```\n" + ], + "line": 30, + "token": 211, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 96, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public ListenableFuture> findMissingBlobs(\n Iterable blobDigests, RequestMetadata requestMetadata) {\n // Some requests have been blocked, and we should tell the client we refuse to perform a lookup.\n try {\n if (inDenyList(requestMetadata)) {\n return immediateFailedFuture(\n Status.UNAVAILABLE\n .withDescription(\"The action associated with this request is forbidden\")\n .asException());\n }\n } catch (IOException e) {\n return immediateFailedFuture(Status.fromThrowable(e).asException());\n }\n\n // Empty blobs are an exceptional case. 
Filter them out.\n // If the user only requested empty blobs we can immediately tell them we already have it.\n Iterable nonEmptyDigests =\n Iterables.filter(blobDigests, (digest) -> digest.getSizeBytes() != 0);\n if (Iterables.isEmpty(nonEmptyDigests)) {\n return immediateFuture(ImmutableList.of());\n }\n\n if (configs.getServer().isFindMissingBlobsViaBackplane()) {\n return findMissingBlobsViaBackplane(nonEmptyDigests, requestMetadata);\n }\n\n return findMissingBlobsQueryingEachWorker(nonEmptyDigests, requestMetadata);\n }\n```\n###test function signature:\n```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception\n```\n### Answer: (use the provided format with backticks)\n", + "name": "findMissingBlobsTest_ViaBackPlane", + "reference": " @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n Set activeWorkers = ImmutableSet.of(\"worker1\", \"worker2\", \"worker3\");\n Set expiredWorkers = ImmutableSet.of(\"workerX\", \"workerY\", \"workerZ\");\n Set imposterWorkers = ImmutableSet.of(\"imposter1\", \"imposter2\", \"imposter3\");\n\n Set availableDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFound1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFound3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build());\n\n Set missingDigests =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"missing1\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missing3\").setSizeBytes(1).build(),\n // a copy is added in final digest list\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build());\n\n Set digestAvailableOnImposters =\n ImmutableSet.of(\n Digest.newBuilder().setHash(\"toBeFoundOnImposter1\").setSizeBytes(1).build(),\n 
Digest.newBuilder().setHash(\"toBeFoundOnImposter2\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"toBeFoundOnImposter3\").setSizeBytes(1).build());\n\n Set emptyDigests =\n new HashSet<>(\n Arrays.asList(\n Digest.newBuilder().setHash(\"empty1\").build(),\n Digest.newBuilder().setHash(\"empty2\").build()));\n\n Iterable allDigests =\n Iterables.concat(\n availableDigests,\n missingDigests,\n emptyDigests,\n digestAvailableOnImposters,\n Arrays.asList(\n Digest.newBuilder().setHash(\"toBeFoundDuplicate\").setSizeBytes(1).build(),\n Digest.newBuilder().setHash(\"missingDuplicate\").setSizeBytes(1).build()));\n\n Map> digestAndWorkersMap = new HashMap<>();\n\n for (Digest digest : availableDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(activeWorkers));\n }\n for (Digest digest : missingDigests) {\n digestAndWorkersMap.put(digest, getRandomSubset(expiredWorkers));\n }\n for (Digest digest : digestAvailableOnImposters) {\n digestAndWorkersMap.put(digest, getRandomSubset(imposterWorkers));\n }\n\n BuildfarmConfigs buildfarmConfigs = instance.getBuildFarmConfigs();\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(true);\n Set activeAndImposterWorkers =\n Sets.newHashSet(Iterables.concat(activeWorkers, imposterWorkers));\n\n when(mockBackplane.getStorageWorkers()).thenReturn(activeAndImposterWorkers);\n when(mockBackplane.getBlobDigestsWorkers(any(Iterable.class))).thenReturn(digestAndWorkersMap);\n when(mockInstanceLoader.load(anyString())).thenReturn(mockWorkerInstance);\n when(mockWorkerInstance.findMissingBlobs(anyIterable(), any(RequestMetadata.class)))\n .thenReturn(Futures.immediateFuture(new ArrayList<>()));\n\n long serverStartTime = 1686951033L; // june 15th, 2023\n Map workersStartTime = new HashMap<>();\n for (String worker : activeAndImposterWorkers) {\n workersStartTime.put(worker, serverStartTime);\n }\n when(mockBackplane.getWorkersStartTimeInEpochSecs(activeAndImposterWorkers))\n .thenReturn(workersStartTime);\n 
long oneDay = 86400L;\n for (Digest digest : availableDigests) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime + oneDay);\n }\n for (Digest digest : digestAvailableOnImposters) {\n when(mockBackplane.getDigestInsertTime(digest)).thenReturn(serverStartTime - oneDay);\n }\n\n Iterable actualMissingDigests =\n instance.findMissingBlobs(allDigests, RequestMetadata.getDefaultInstance()).get();\n Iterable expectedMissingDigests =\n Iterables.concat(missingDigests, digestAvailableOnImposters);\n\n assertThat(actualMissingDigests).containsExactlyElementsIn(expectedMissingDigests);\n verify(mockWorkerInstance, atMost(3))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n verify(mockWorkerInstance, atLeast(1))\n .findMissingBlobs(anyIterable(), any(RequestMetadata.class));\n\n for (Digest digest : actualMissingDigests) {\n assertThat(digest).isNotIn(availableDigests);\n assertThat(digest).isNotIn(emptyDigests);\n assertThat(digest).isIn(expectedMissingDigests);\n }\n\n // reset BuildfarmConfigs\n buildfarmConfigs.getServer().setFindMissingBlobsViaBackplane(false);\n }", + "completions": [ + "```java\n @Test\n public void findMissingBlobsTest_ViaBackPlane() throws Exception {\n // Given\n Iterable blobDigests = Arrays.asList(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n RequestMetadata requestMetadata = RequestMetadata.newBuilder().setClientId(\"client1\").build();\n\n // When\n ListenableFuture> future =\n underTest.findMissingBlobs(blobDigests, requestMetadata);\n\n // Then\n Iterable missingBlobs = future.get();\n assertThat(missingBlobs).containsExactly(\n Digest.newBuilder().setHash(\"hash1\").setSizeBytes(10).build(),\n Digest.newBuilder().setHash(\"hash2\").setSizeBytes(20).build(),\n Digest.newBuilder().setHash(\"hash3\").setSizeBytes(30).build()\n );\n }\n```\n" 
+ ], + "line": 37, + "token": 276, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 97, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n public ComponentXO readComponent(final Repository repository, final EntityId componentId) {\n Component component;\n List assets;\n try (StorageTx storageTx = repository.facet(StorageFacet.class).txSupplier().get()) {\n storageTx.begin();\n component = storageTx.findComponent(componentId);\n if (component == null) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n Iterable browsedAssets = storageTx.browseAssets(component);\n if (browsedAssets == null || Iterables.isEmpty(browsedAssets)) {\n throw new WebApplicationException(Status.NOT_FOUND);\n }\n\n assets = newArrayList(browsedAssets);\n }\n ensurePermissions(repository, assets, BreadActions.BROWSE);\n return COMPONENT_CONVERTER(component, repository.getName());\n }\n```\n###test function signature:\n```java\n @Test\n public void testReadComponent_notFound()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testReadComponent_notFound", + "reference": " @Test\n public void testReadComponent_notFound() {\n Component component = mock(Component.class);\n when(storageTx.findComponent(eq(new DetachedEntityId(\"someid\")))).thenReturn(component);\n when(storageTx.browseAssets(component)).thenReturn(null);\n try {\n underTest.readComponent(\"someid\", \"testRepositoryName\");\n fail(\"Exception should have been thrown\");\n }\n catch (WebApplicationException e) {\n assertThat(e.getResponse(), is(notNullValue()));\n assertThat(e.getResponse().getStatus(), is(404));\n }\n }", + "completions": [ + "```java\n @Test\n public void testReadComponent_notFound() {\n // Given\n Repository repository = mock(Repository.class);\n EntityId componentId = 
mock(EntityId.class);\n StorageFacet storageFacet = mock(StorageFacet.class);\n StorageTx storageTx = mock(StorageTx.class);\n when(repository.facet(StorageFacet.class)).thenReturn(storageFacet);\n when(storageFacet.txSupplier()).thenReturn(() -> storageTx);\n when(storageTx.findComponent(componentId)).thenReturn(null);\n\n // When\n try {\n readComponent(repository, componentId);\n fail(\"Expected WebApplicationException\");\n } catch (WebApplicationException e) {\n // Then\n assertEquals(Status.NOT_FOUND.getStatusCode(), e.getResponse().getStatus());\n }\n }\n```\n" + ], + "line": 29, + "token": 205, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 98, + "prompt": "\nYou are an expert Java programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @Override\n public CompletableFuture auth(MQTT3AuthData authData) {\n try {\n HttpRequest request = HttpRequest.newBuilder()\n .uri(URI.create(webhookURI + \"/auth\"))\n .header(\"Content-Type\", \"application/json\")\n .POST(HttpRequest.BodyPublishers.ofString(JsonFormat.printer().print(authData)))\n .timeout(Duration.ofSeconds(5))\n .build();\n return httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString())\n .thenApply(response -> {\n if (response.statusCode() == 200) {\n try {\n MQTT3AuthResult.Builder resultBuilder = MQTT3AuthResult.newBuilder();\n JsonFormat.parser()\n .ignoringUnknownFields()\n .merge(response.body(), resultBuilder);\n return resultBuilder.build();\n } catch (Throwable e) {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build();\n }\n } else {\n return MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(\"Authenticate failed\")\n .build())\n .build();\n }\n })\n .exceptionally(e 
-> {\n System.out.println(\"Failed to call webhook: \" + e.getMessage());\n return null;\n });\n } catch (Throwable e) {\n return CompletableFuture.completedFuture(MQTT3AuthResult.newBuilder()\n .setReject(Reject.newBuilder()\n .setCode(Reject.Code.Error)\n .setReason(e.getMessage())\n .build())\n .build());\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testAuth()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testAuth", + "reference": " @Test\n public void testAuth() {\n MQTT3AuthData authedUser = MQTT3AuthData.newBuilder()\n .setUsername(\"authUser\")\n .build();\n\n MQTT3AuthData unauthUser = MQTT3AuthData.newBuilder()\n .setUsername(\"unauthUser\")\n .build();\n testServer.addAuthedUser(authedUser.getUsername());\n WebHookBasedAuthProvider provider = new WebHookBasedAuthProvider(testServer.getURI());\n MQTT3AuthResult authResult = provider.auth(authedUser).join();\n assertTrue(authResult.hasOk());\n\n authResult = provider.auth(unauthUser).join();\n assertFalse(authResult.hasOk());\n }", + "completions": [ + "```java\n @Test\n public void testAuth() {\n // Given\n String webhookURI = \"http://localhost:8080\";\n HttpClient httpClient = HttpClient.newHttpClient();\n MQTT3AuthData authData = MQTT3AuthData.newBuilder()\n .setClientId(\"client1\")\n .setUsername(\"user1\")\n .setPassword(\"password1\")\n .build();\n MQTT3AuthResult expectedResult = MQTT3AuthResult.newBuilder()\n .setSuccess(Success.newBuilder()\n .setIsSuperuser(false)\n .build())\n .build();\n\n // When\n CompletableFuture result = new MQTT3AuthService(webhookURI, httpClient).auth(authData);\n\n // Then\n assertNotNull(result);\n assertEquals(expectedResult, result.join());\n }\n```\n" + ], + "line": 59, + "token": 335, + "line_diff": 5, + "token_diff": 3 + }, + { + "id": 99, + "prompt": "\nYou are an expert Java programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be java test function and without anything else.\n\n### focal method:\n```java\n @VisibleForTesting\n void validateRoleDuplicate(String name, String code, Long id) {\n // 0. 超级管理员,不允许创建\n if (RoleCodeEnum.isSuperAdmin(code)) {\n throw exception(ROLE_ADMIN_CODE_ERROR, code);\n }\n // 1. 该 name 名字被其它角色所使用\n RoleDO role = roleMapper.selectByName(name);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_NAME_DUPLICATE, name);\n }\n // 2. 是否存在相同编码的角色\n if (!StringUtils.hasText(code)) {\n return;\n }\n // 该 code 编码被其它角色所使用\n role = roleMapper.selectByCode(code);\n if (role != null && !role.getId().equals(id)) {\n throw exception(ROLE_CODE_DUPLICATE, code);\n }\n }\n```\n###test function signature:\n```java\n @Test\n public void testValidateRoleDuplicate_success()\n```\n### Answer: (use the provided format with backticks)\n", + "name": "testValidateRoleDuplicate_success", + "reference": " @Test\n public void testValidateRoleDuplicate_success() {\n // 调用,不会抛异常\n roleService.validateRoleDuplicate(randomString(), randomString(), null);\n }", + "completions": [ + "```java\n @Test\n public void testValidateRoleDuplicate_success() {\n // Given\n String name = \"TestRole\";\n String code = \"TEST_ROLE\";\n Long id = 1L;\n\n // When\n // Call the method to be tested\n validateRoleDuplicate(name, code, id);\n\n // Then\n // No exception should be thrown\n }\n```\n" + ], + "line": 32, + "token": 213, + "line_diff": 1, + "token_diff": 1 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Java/8/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/8/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..470eb4fe296571d36af9c8d31062df0f727b777a --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/8/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, 
subset_0(22~26),subset_1(26~32),subset_2(33~38),subset_3(39~44),subset_4(44~51),subset_5(51~74),subset_6(85~85),subset_7(94~128) +StarCoder2-15b,33.53,31.84,29.20,35.55,32.95,25.96,23.55,17.77 +CodeLlama-7b,37.14,34.38,28.18,36.84,32.22,29.76,19.82,23.49 +CodeLlama-13b,30.62,31.71,26.14,34.43,31.27,27.68,21.71,21.73 +CodeLlama-34b,33.27,33.66,24.49,36.26,30.75,28.87,21.69,23.39 +DeepSeek-Coder-1.3b,27.48,29.11,30.19,34.76,30.91,31.49,23.99,20.56 +DeepSeek-Coder-6.7b,16.65,22.56,22.65,21.89,18.83,21.96,22.51,22.39 +DeepSeek-Coder-33b,29.25,35.16,35.16,38.43,34.81,31.47,23.15,20.25 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/8/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Java/8/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..83952a9045259b67cd16529b52bb64775f83cb7e --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/8/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(139~179),subset_1(181~241),subset_2(241~290),subset_3(300~358),subset_4(358~367),subset_5(373~533),subset_6(544~606),subset_7(606~952) +StarCoder2-15b,30.41,35.66,25.70,31.79,38.08,27.93,21.87,19.46 +CodeLlama-7b,32.83,37.73,29.81,32.13,35.39,30.97,21.23,22.08 +CodeLlama-13b,27.04,36.15,26.33,30.35,34.57,27.66,23.15,20.30 +CodeLlama-34b,30.22,33.89,25.98,31.87,40.98,24.83,21.86,23.22 +DeepSeek-Coder-1.3b,26.76,30.70,29.82,30.10,38.03,28.88,25.47,19.08 +DeepSeek-Coder-6.7b,18.59,20.53,22.52,23.10,20.67,19.06,22.10,22.80 +DeepSeek-Coder-33b,29.98,38.19,30.10,33.34,39.18,34.03,22.20,21.20 diff --git a/dataset/Test Generation/ComplexCodeEval-Java/8/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Java/8/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Java/8/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def 
analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/3/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Python/3/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..377f1d801eb5d2a75006720d165e1909eec234e6 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/3/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = 
torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 0, + 
"token_diff": 0 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/3/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/3/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..95d35063ee5628f014e069cd168828498ca9c05d --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/3/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~187), subset_1,subset_2(749~749) +StarCoder2-15b,24.55,0.00,20.70 +CodeLlama-7b,26.44,0.00,27.18 +CodeLlama-13b,26.46,0.00,31.82 +CodeLlama-34b,27.20,0.00,24.58 +DeepSeek-Coder-1.3b,25.85,0.00,21.28 +DeepSeek-Coder-6.7b,24.76,0.00,20.74 +DeepSeek-Coder-33b,29.22,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/3/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/3/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..31534d1994cc988d1297fdd390e4299be9bd7f7d --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/3/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~1981),subset_1,subset_2(7038~7038) +StarCoder2-15b,24.55,0.00,20.70 +CodeLlama-7b,26.44,0.00,27.18 +CodeLlama-13b,26.46,0.00,31.82 +CodeLlama-34b,27.20,0.00,24.58 +DeepSeek-Coder-1.3b,25.85,0.00,21.28 +DeepSeek-Coder-6.7b,24.76,0.00,20.74 +DeepSeek-Coder-33b,29.22,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/3/EI/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/3/EI/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/3/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data 
= json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/3/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Python/3/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..19e997a27ff703504b14dcc97fb493ec631ae534 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/3/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = 
torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 2, + "token_diff": 2 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 2, + "token_diff": 2 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 1, + 
"token_diff": 1 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 2, + "token_diff": 2 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/3/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/3/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..5dc0c40d8980f15c184c06e278bef6cc19a7858e --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/3/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~33),subset_1(35~61),subset_2(62~749) +StarCoder2-15b,26.65,23.86,22.64 +CodeLlama-7b,23.41,30.31,25.81 +CodeLlama-13b,28.85,26.72,24.79 +CodeLlama-34b,29.54,24.36,27.28 +DeepSeek-Coder-1.3b,23.88,27.34,25.70 +DeepSeek-Coder-6.7b,24.09,26.56,23.05 +DeepSeek-Coder-33b,29.68,28.99,28.43 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/3/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/3/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..056139aad1888316243924ab5461e55b255d6232 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/3/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~273),subset_1(273~437),subset_2(453~7038) +StarCoder2-15b,26.65,23.75,22.66 +CodeLlama-7b,24.21,29.18,26.12 +CodeLlama-13b,29.35,25.72,25.11 +CodeLlama-34b,29.58,24.89,26.60 +DeepSeek-Coder-1.3b,24.39,26.55,25.97 +DeepSeek-Coder-6.7b,25.26,25.08,23.32 +DeepSeek-Coder-33b,29.73,28.05,29.21 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/3/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/3/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/3/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 
'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/4/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Python/4/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..2168ebdda26e1312eecec960b6711f70475570ae --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/4/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = 
torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 0, + 
"token_diff": 0 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/4/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/4/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..734820549a2b06855aac7eae9ad225634d8f703a --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/4/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models,subset_0(15~187),subset_1,subset_2,subset_3(749~749) +StarCoder2-15b,24.55,0.00,0.00,20.70 +CodeLlama-7b,26.44,0.00,0.00,27.18 +CodeLlama-13b,26.46,0.00,0.00,31.82 +CodeLlama-34b,27.20,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.85,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.76,0.00,0.00,20.74 +DeepSeek-Coder-33b,29.22,0.00,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/4/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/4/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..56aa443e2d908ee896d53b8d610ef2965a7c183d --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/4/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~1112),subset_1(1981~1981),subset_2,subset_3(7038~7038) +StarCoder2-15b,24.65,15.68,0.00,20.70 +CodeLlama-7b,26.49,22.34,0.00,27.18 +CodeLlama-13b,26.64,10.43,0.00,31.82 +CodeLlama-34b,27.29,19.34,0.00,24.58 +DeepSeek-Coder-1.3b,25.90,21.64,0.00,21.28 +DeepSeek-Coder-6.7b,24.71,29.94,0.00,20.74 +DeepSeek-Coder-33b,29.18,32.72,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/4/EI/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/4/EI/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/4/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + 
+def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/4/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Python/4/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..fbde96b6334ae8f4681fad1e4f329d52e4faa8cd --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/4/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = 
torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 2, + "token_diff": 2 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 3, + "token_diff": 3 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 2, + 
"token_diff": 2 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 2, + "token_diff": 0 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 3, + "token_diff": 3 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 0, + "token_diff": 1 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/4/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/4/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..845f8613e4aa3d102fd7eaae538e584d275a824f --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/4/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~28),subset_1(29~47),subset_2(48~75),subset_3(76~749) +StarCoder2-15b,28.61,22.39,21.84,24.65 +CodeLlama-7b,23.83,27.10,26.08,28.90 +CodeLlama-13b,29.39,26.21,27.09,24.41 +CodeLlama-34b,30.72,24.59,25.41,27.52 +DeepSeek-Coder-1.3b,23.47,26.99,24.69,27.34 +DeepSeek-Coder-6.7b,24.92,25.28,24.63,23.41 +DeepSeek-Coder-33b,29.89,27.96,28.83,29.35 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/4/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/4/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..55288265ecdea80b23e1c5462a7bc24945c43fc4 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/4/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~225),subset_1(226~322),subset_2(345~571),subset_3(586~7038) +StarCoder2-15b,26.21,25.31,22.31,23.79 +CodeLlama-7b,24.57,28.22,26.47,26.65 +CodeLlama-13b,30.95,22.86,27.72,24.83 +CodeLlama-34b,28.46,27.89,24.06,28.03 +DeepSeek-Coder-1.3b,24.09,26.57,26.00,25.83 +DeepSeek-Coder-6.7b,24.78,25.51,24.96,22.99 +DeepSeek-Coder-33b,29.04,29.32,28.66,29.12 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/4/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/4/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/4/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from 
collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/5/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Python/5/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..0e6c0c3a75346f5bfc19b5a70007ff792ff1af67 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/5/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = 
torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 0, + 
"token_diff": 0 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/5/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/5/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..f4a8c45ee064a671089531a3ee20cf471744f695 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/5/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~160),subset_1(166~187),subset_2,subset_3,subset_4(749~749) +StarCoder2-15b,24.30,28.83,0.00,0.00,20.70 +CodeLlama-7b,26.69,21.94,0.00,0.00,27.18 +CodeLlama-13b,26.72,22.07,0.00,0.00,31.82 +CodeLlama-34b,27.19,27.40,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.37,34.47,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.65,26.73,0.00,0.00,20.74 +DeepSeek-Coder-33b,28.91,34.42,0.00,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/5/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/5/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..2dc13f7b82d4eaa6420800cadc70ed771f77b0a4 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/5/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~1112),subset_1(1981~1981),subset_2,subset_3,subset_4(7038~7038) +StarCoder2-15b,24.65,15.68,0.00,0.00,20.70 +CodeLlama-7b,26.49,22.34,0.00,0.00,27.18 +CodeLlama-13b,26.64,10.43,0.00,0.00,31.82 +CodeLlama-34b,27.29,19.34,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.90,21.64,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.71,29.94,0.00,0.00,20.74 +DeepSeek-Coder-33b,29.18,32.72,0.00,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/5/EI/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/5/EI/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test 
Generation/ComplexCodeEval-Python/5/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/5/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Python/5/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..0e8b6521b5057fa0d35b0ad858addbfca1ad4509 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/5/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = 
torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 3, + "token_diff": 3 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 4, + "token_diff": 4 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 2, + 
"token_diff": 3 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 3, + "token_diff": 1 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 4, + "token_diff": 3 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 1, + "token_diff": 1 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/5/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/5/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..67c2c67f0510205deeb1d409fd734fbbcfde06b9 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/5/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~26),subset_1(27~40),subset_2(41~56),subset_3(56~95),subset_4(95~749) +StarCoder2-15b,27.80,24.21,23.65,20.97,25.05 +CodeLlama-7b,23.45,22.46,30.49,27.85,28.14 +CodeLlama-13b,29.74,27.93,26.04,25.15,25.03 +CodeLlama-34b,30.95,26.29,23.04,26.33,28.36 +DeepSeek-Coder-1.3b,23.19,24.99,27.18,25.69,27.06 +DeepSeek-Coder-6.7b,24.03,24.78,26.02,23.78,24.19 +DeepSeek-Coder-33b,30.01,29.30,28.21,27.12,30.49 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/5/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/5/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..6107c3b264684bd0dfad71231052b75a0b4ed250 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/5/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~195),subset_1(205~280),subset_2(280~390),subset_3(405~645),subset_4(692~7038) +StarCoder2-15b,25.03,28.26,22.49,22.60,23.76 +CodeLlama-7b,22.90,28.39,29.13,24.22,27.77 +CodeLlama-13b,31.41,23.82,26.33,27.52,24.13 +CodeLlama-34b,27.20,31.28,24.86,24.59,27.79 +DeepSeek-Coder-1.3b,23.39,26.03,27.07,25.85,25.78 +DeepSeek-Coder-6.7b,23.45,27.18,25.02,24.54,22.61 +DeepSeek-Coder-33b,28.14,31.46,28.49,28.20,29.13 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/5/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/5/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- 
/dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/5/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/6/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Python/6/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..6baa59a01c6dd7c3aa37bb3736f6ec6e120cd929 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/6/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = 
torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 0, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 0, + 
"token_diff": 0 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/6/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/6/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..303a59d42d1bdf5e1735d4526f9dc73b559ffe87 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/6/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~133),subset_1(160~187),subset_2,subset_3,subset_4,subset_5(749~749) +StarCoder2-15b,24.35,27.41,0.00,0.00,0.00,20.70 +CodeLlama-7b,26.59,24.25,0.00,0.00,0.00,27.18 +CodeLlama-13b,26.66,23.69,0.00,0.00,0.00,31.82 +CodeLlama-34b,27.00,30.05,0.00,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.33,33.52,0.00,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.64,26.60,0.00,0.00,0.00,20.74 +DeepSeek-Coder-33b,28.72,36.20,0.00,0.00,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/6/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/6/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..580002ea65b01850a5474ca28fcff3e33fc59433 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/6/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~1112),subset_1(1981~1981),subset_2,subset_3,subset_4,subset_5(7038~7038) +StarCoder2-15b,24.65,15.68,0.00,0.00,0.00,20.70 +CodeLlama-7b,26.49,22.34,0.00,0.00,0.00,27.18 +CodeLlama-13b,26.64,10.43,0.00,0.00,0.00,31.82 +CodeLlama-34b,27.29,19.34,0.00,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.90,21.64,0.00,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.71,29.94,0.00,0.00,0.00,20.74 +DeepSeek-Coder-33b,29.18,32.72,0.00,0.00,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/6/EI/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/6/EI/tongji.py new file mode 100644 index 
0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/6/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/6/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Python/6/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..d5e24ec41baa0ffb1d84791f5618c342b91d3982 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/6/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = 
torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 4, + "token_diff": 4 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 2, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 1, + "token_diff": 3 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 5, + "token_diff": 5 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 3, + 
"token_diff": 3 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 3, + "token_diff": 1 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 5, + "token_diff": 4 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 1, + "token_diff": 1 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/6/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/6/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..30e169e4109de55d8f2136966749f8e01f16ca0a --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/6/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~25),subset_1(26~33),subset_2(35~48),subset_3(49~62),subset_4(62~97),subset_5(100~749) +StarCoder2-15b,28.37,24.58,23.50,23.66,20.94,24.81 +CodeLlama-7b,22.58,24.25,29.78,30.65,25.08,26.45 +CodeLlama-13b,29.58,27.96,25.56,27.22,22.71,27.23 +CodeLlama-34b,30.44,28.44,23.55,25.52,26.18,28.08 +DeepSeek-Coder-1.3b,22.60,25.16,27.36,26.48,25.41,26.78 +DeepSeek-Coder-6.7b,23.19,24.98,26.84,25.88,22.35,23.94 +DeepSeek-Coder-33b,30.14,29.12,27.23,29.99,26.94,30.48 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/6/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/6/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..cc6273db633d95973c37938ddae8126fea92dd94 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/6/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~169),subset_1(183~273),subset_2(273~345),subset_3(347~453),subset_4(464~804),subset_5(804~7038) +StarCoder2-15b,25.64,27.79,23.40,23.82,20.49,25.00 +CodeLlama-7b,22.21,26.21,30.04,29.10,22.55,28.66 +CodeLlama-13b,32.65,25.61,22.48,28.17,24.07,26.15 +CodeLlama-34b,26.71,32.83,25.42,24.63,25.54,27.58 +DeepSeek-Coder-1.3b,23.09,25.69,27.82,25.52,24.04,27.61 +DeepSeek-Coder-6.7b,25.06,25.46,24.87,25.85,22.36,23.58 +DeepSeek-Coder-33b,28.22,31.45,27.25,29.55,27.76,29.85 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/6/QS/tongji.py b/dataset/Test 
Generation/ComplexCodeEval-Python/6/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/6/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/7/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Python/7/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..1bfa3ad2b10e19b576b6b767acede2c4b005d227 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/7/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": 
"\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = 
graph_CH4\n lattice = torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 1, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 0, + 
"token_diff": 0 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/7/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/7/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..6f376c6188f58530671329e541a7a7a32efc8b6e --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/7/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~104),subset_1(133~187),subset_2,subset_3,subset_4,subset_5,subset_6(749~749) +StarCoder2-15b,24.25,27.66,0.00,0.00,0.00,0.00,20.70 +CodeLlama-7b,26.39,26.96,0.00,0.00,0.00,0.00,27.18 +CodeLlama-13b,26.72,23.82,0.00,0.00,0.00,0.00,31.82 +CodeLlama-34b,26.98,29.47,0.00,0.00,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.45,30.19,0.00,0.00,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.67,25.76,0.00,0.00,0.00,0.00,20.74 +DeepSeek-Coder-33b,28.81,33.38,0.00,0.00,0.00,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/7/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/7/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..5897dae2eaf7694bc4d66981f408e6164cb199a3 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/7/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~992),subset_1(1112~1981),subset_2,subset_3,subset_4,subset_5,subset_6(7038~7038) +StarCoder2-15b,24.64,20.71,0.00,0.00,0.00,0.00,20.70 +CodeLlama-7b,26.28,33.79,0.00,0.00,0.00,0.00,27.18 +CodeLlama-13b,26.69,16.50,0.00,0.00,0.00,0.00,31.82 +CodeLlama-34b,27.27,24.06,0.00,0.00,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.99,19.48,0.00,0.00,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.75,25.44,0.00,0.00,0.00,0.00,20.74 +DeepSeek-Coder-33b,29.14,32.65,0.00,0.00,0.00,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/7/EI/tongji.py b/dataset/Test 
Generation/ComplexCodeEval-Python/7/EI/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/7/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/7/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Python/7/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..1441a3dbdfc7da43254cef554f06e74f10437c4b --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/7/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": 
"\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = 
graph_CH4\n lattice = torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 4, + "token_diff": 4 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 2, + "token_diff": 4 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 6, + "token_diff": 6 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 3, + 
"token_diff": 4 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 4, + "token_diff": 1 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 4, + "token_diff": 2 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 6, + "token_diff": 5 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 1, + "token_diff": 2 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/7/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/7/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..3d38b0779e0e4b0b9cbb95f83ea22ef8ccefd0f7 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/7/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~24),subset_1(25~31),subset_2(31~43),subset_3(43~54),subset_4(56~70),subset_5(70~104),subset_6(104~749) +StarCoder2-15b,29.05,24.67,22.80,24.33,20.94,22.59,25.53 +CodeLlama-7b,21.57,25.27,25.26,31.56,28.68,27.55,25.91 +CodeLlama-13b,30.51,28.09,25.45,28.21,23.53,24.90,26.36 +CodeLlama-34b,30.78,27.64,25.71,22.99,26.28,27.86,27.48 +DeepSeek-Coder-1.3b,22.34,25.37,26.73,27.58,24.04,26.44,27.12 +DeepSeek-Coder-6.7b,25.09,24.86,24.24,25.94,26.58,21.44,23.71 +DeepSeek-Coder-33b,30.47,28.20,28.06,29.35,28.57,27.93,30.30 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/7/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/7/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..bfed0bd7c3271183485e6ca5587106bbe32877bf --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/7/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~166),subset_1(168~242),subset_2(250~309),subset_3(310~384),subset_4(389~537),subset_5(560~807),subset_6(850~7038) +StarCoder2-15b,25.42,27.23,27.15,21.43,22.69,22.32,24.91 +CodeLlama-7b,22.90,25.36,31.09,27.10,25.25,25.94,28.06 +CodeLlama-13b,32.32,28.32,18.63,26.07,30.21,22.07,26.82 +CodeLlama-34b,26.09,33.28,27.24,24.44,23.49,27.79,27.23 +DeepSeek-Coder-1.3b,22.77,25.98,25.93,27.39,25.31,24.97,27.19 +DeepSeek-Coder-6.7b,25.44,25.53,25.88,24.16,24.85,22.53,23.40 +DeepSeek-Coder-33b,28.10,31.11,28.34,28.25,28.35,29.11,29.76 diff --git 
a/dataset/Test Generation/ComplexCodeEval-Python/7/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/7/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/7/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/8/EI/EI.json b/dataset/Test Generation/ComplexCodeEval-Python/8/EI/EI.json new file mode 100644 index 0000000000000000000000000000000000000000..fa9d39ed749f97cf579356987a6e60d93e402e8d --- /dev/null +++ b/dataset/Test 
Generation/ComplexCodeEval-Python/8/EI/EI.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n 
np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 0, + "token_diff": 0 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 1, + "token_diff": 0 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 0, + 
"token_diff": 0 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 0, + "token_diff": 0 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 0, + "token_diff": 0 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/8/EI/line_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/8/EI/line_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..abc6dfb4b47fc193c4ff5adae0922cbc9812aa58 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/8/EI/line_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~104),subset_1(133~187),subset_2,subset_3,subset_4,subset_5,subset_6,subset_7(749~749) +StarCoder2-15b,24.25,27.66,0.00,0.00,0.00,0.00,0.00,20.70 +CodeLlama-7b,26.39,26.96,0.00,0.00,0.00,0.00,0.00,27.18 +CodeLlama-13b,26.72,23.82,0.00,0.00,0.00,0.00,0.00,31.82 +CodeLlama-34b,26.98,29.47,0.00,0.00,0.00,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.45,30.19,0.00,0.00,0.00,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.67,25.76,0.00,0.00,0.00,0.00,0.00,20.74 +DeepSeek-Coder-33b,28.81,33.38,0.00,0.00,0.00,0.00,0.00,25.46 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/8/EI/token_counts_EI.csv b/dataset/Test Generation/ComplexCodeEval-Python/8/EI/token_counts_EI.csv new file mode 100644 index 0000000000000000000000000000000000000000..847be51ed4c3b36c17867768dadeb11ffb3b7e81 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/8/EI/token_counts_EI.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~985),subset_1(992~1112),subset_2(1981~1981),subset_3,subset_4,subset_5,subset_6,subset_7(7038~7038) +StarCoder2-15b,24.57,28.15,15.68,0.00,0.00,0.00,0.00,20.70 +CodeLlama-7b,26.16,41.70,22.34,0.00,0.00,0.00,0.00,27.18 +CodeLlama-13b,26.55,30.33,10.43,0.00,0.00,0.00,0.00,31.82 +CodeLlama-34b,27.27,28.12,19.34,0.00,0.00,0.00,0.00,24.58 +DeepSeek-Coder-1.3b,25.92,24.69,21.64,0.00,0.00,0.00,0.00,21.28 +DeepSeek-Coder-6.7b,24.78,21.50,29.94,0.00,0.00,0.00,0.00,20.74 +DeepSeek-Coder-33b,29.05,34.66,32.72,0.00,0.00,0.00,0.00,25.46 diff --git a/dataset/Test 
Generation/ComplexCodeEval-Python/8/EI/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/8/EI/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/8/EI/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/8/QS/QS.json b/dataset/Test Generation/ComplexCodeEval-Python/8/QS/QS.json new file mode 100644 index 0000000000000000000000000000000000000000..79f8fa75d9c243c882929a94ee2f5b92a44f5c85 --- /dev/null +++ b/dataset/Test 
Generation/ComplexCodeEval-Python/8/QS/QS.json @@ -0,0 +1,1302 @@ +[ + { + "id": 0, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_pair_vector_and_distance(g: dgl.DGLGraph):\n dst_pos = g.ndata[\"pos\"][g.edges()[1]] + g.edata[\"pbc_offshift\"]\n src_pos = g.ndata[\"pos\"][g.edges()[0]]\n bond_vec = dst_pos - src_pos\n bond_dist = torch.norm(bond_vec, dim=1)\n\n return bond_vec, bond_dist\n```\n###test function signature:\n```python\n\n def test_compute_angle(self, graph_Mo, graph_CH4):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_angle", + "reference": "\n def test_compute_angle(self, graph_Mo, graph_CH4):\n s1, g1, state1 = graph_Mo\n lattice = torch.tensor(s1.lattice.matrix, dtype=matgl.float_th).unsqueeze(dim=0)\n g1.edata[\"pbc_offshift\"] = torch.matmul(g1.edata[\"pbc_offset\"], lattice[0])\n g1.ndata[\"pos\"] = g1.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g1)\n g1.edata[\"bond_vec\"] = bv\n g1.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g1, 4.0)\n\n line_graph = create_line_graph(g1, 4.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n theta = np.arccos(np.clip(cos_loop, -1.0 + 1e-7, 1.0 - 1e-7))\n np.testing.assert_array_almost_equal(np.sort(theta), np.sort(np.array(line_graph.edata[\"theta\"])), decimal=4)\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n 
np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n s2, g2, state2 = graph_CH4\n lattice = torch.tensor(np.identity(3), dtype=matgl.float_th).unsqueeze(dim=0)\n g2.edata[\"pbc_offshift\"] = torch.matmul(g2.edata[\"pbc_offset\"], lattice[0])\n g2.ndata[\"pos\"] = g2.ndata[\"frac_coords\"] @ lattice[0]\n bv, bd = compute_pair_vector_and_distance(g2)\n g2.edata[\"bond_vec\"] = bv\n g2.edata[\"bond_dist\"] = bd\n cos_loop = _calculate_cos_loop(g2, 2.0)\n\n line_graph = create_line_graph(g2, 2.0)\n line_graph.apply_edges(compute_theta_and_phi)\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )\n\n # test only compute theta\n line_graph.apply_edges(partial(compute_theta, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.arccos(np.array(cos_loop))), np.sort(np.array(line_graph.edata[\"theta\"]))\n )\n\n # test only compute theta with cosine\n _ = line_graph.edata.pop(\"cos_theta\")\n line_graph.apply_edges(partial(compute_theta, cosine=True, directed=False))\n np.testing.assert_array_almost_equal(\n np.sort(np.array(cos_loop)), np.sort(np.array(line_graph.edata[\"cos_theta\"]))\n )", + "completions": [ + "```python\n def test_compute_angle(self, graph_Mo, graph_CH4):\n \"\"\"\n Test compute_angle function\n \"\"\"\n # compute angle between Mo-C and Mo-H\n angle_Mo_C = compute_angle(graph_Mo, graph_CH4, 0, 1)\n angle_Mo_H = compute_angle(graph_Mo, graph_CH4, 0, 2)\n\n # check angle between Mo-C and Mo-H\n self.assertAlmostEqual(angle_Mo_C, 104.5, places=1)\n self.assertAlmostEqual(angle_Mo_H, 109.5, places=1)\n```\n" + ], + "line": 16, + "token": 148, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 1, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def _make_schedule(self, targets: List[Task]) -> List[Task]:\n self.schedule = []\n self.dependencies = self._build_dependencies(targets)\n\n edge_tups = []\n for node in self.dependencies:\n for dependency in self.dependencies[node]:\n edge_tups.append((dependency, node))\n\n for task in targets:\n # add edges from a dummy node to each target to guarantee\n # they will be included in the final schedule\n edge_tups.append((Executor.DUMMY_TASK_VALUE, task))\n\n def _compare_key(task: Union[Task, str]):\n if task == Executor.DUMMY_TASK_VALUE:\n return (\"\", 0)\n return (\n task.group_label() or \"\",\n -task.priority(),\n )\n\n graph = networkx.DiGraph(edge_tups)\n res = [\n t\n for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)\n if t != Executor.DUMMY_TASK_VALUE\n ]\n return res\n```\n###test function signature:\n```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_group_label_with_dependencies", + "reference": "\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = create_mock_task(\"task1\", result=1, group_label=\"group1\")\n task2 = create_mock_task(\n \"task2\", result=2, dependencies={\"task1\": task1}, group_label=\"group2\"\n )\n task3 = create_mock_task(\n \"task3\", result=3, dependencies={\"task2\": task2}, group_label=\"group1\"\n )\n\n executor = Executor([task3])\n schedule = executor._make_schedule([task3])\n scheduled_labels = [\n task.group_label() for task in schedule if task.group_label()\n ]\n\n # Check if task3 is scheduled after task1 and task2 due to dependency, even though it has the same group label as task1\n 
group1_indices = [\n i for i, label in enumerate(scheduled_labels) if label == \"group1\"\n ]\n group2_index = scheduled_labels.index(\"group2\")\n\n assert (\n group1_indices[-1] > group2_index\n ), \"Task with the same group label but later dependency was not scheduled after different group label\"", + "completions": [ + "```python\n\n def test_group_label_with_dependencies(self):\n # Creating tasks with dependencies and group labels\n task1 = Task(name=\"task1\", group_label=\"group1\")\n task2 = Task(name=\"task2\", group_label=\"group1\")\n task3 = Task(name=\"task3\", group_label=\"group2\")\n task4 = Task(name=\"task4\", group_label=\"group2\")\n task5 = Task(name=\"task5\", group_label=\"group3\")\n task6 = Task(name=\"task6\", group_label=\"group3\")\n task7 = Task(name=\"task7\", group_label=\"group4\")\n task8 = Task(name=\"task8\", group_label=\"group4\")\n task9 = Task(name=\"task9\", group_label=\"group5\")\n task10 = Task(name=\"task10\", group_label=\"group5\")\n\n # Creating dependencies between tasks\n task1.depends_on(task2)\n task2.depends_on(task3)\n task3.depends_on(task4)\n task4.depends_on(task5)\n task5.depends_on(task6)\n task6.depends_on(task7)\n task7.depends_on(task8)\n task8.depends_on(task9)\n task9.depends_on(task10)\n\n # Creating a list of tasks\n tasks = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies\n tasks_with_dependencies = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with group labels\n tasks_with_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, 
task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_group_labels = [task1, task2, task3, task4, task5, task6, task7, task8, task9, task10]\n\n # Creating a list of tasks with dependencies and group labels\n tasks_with_dependencies_and_" + ], + "line": 36, + "token": 242, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 2, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef setup_cache(\n max_batch_size, max_seq_len, dim, heads, layers, block_size, rope_base\n):\n if max_seq_len >= max_seq_len and max_batch_size >= max_batch_size:\n return\n\n head_dim = dim // heads\n max_seq_len = find_multiple(max_seq_len, 8)\n\n for b in layers:\n b.attention.kv_cache = KVCache(\n max_batch_size, max_seq_len, heads, head_dim\n )\n\n freq_cis = precompute_freq_cis(block_size, dim // heads, rope_base)\n causal_mask = torch.tril(\n torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)\n )\n\n return causal_mask, freq_cis\n```\n###test function signature:\n```python\n\ndef test_setup_cache_max_batch_size_greater_than_max():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_setup_cache_max_batch_size_greater_than_max", + "reference": "\ndef test_setup_cache_max_batch_size_greater_than_max():\n layers = [\n Mock(attention=Mock(kw_cache=None)),\n Mock(attention=Mock(kw_cache=None)),\n ]\n max_batch_size = 64\n max_seq_len = 32\n heads = 4\n head_dim = 32\n block_size = 32\n rope_base = 1000\n setup_cache(\n max_batch_size + 10,\n max_seq_len,\n head_dim * heads,\n heads,\n layers,\n block_size,\n rope_base,\n )\n for layer in layers:\n assert isinstance(layer.attention.kw_cache, KVCache)\n assert layer.attention.kw_cache.k_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )\n assert layer.attention.kw_cache.v_cache.shape == torch.Size(\n [max_batch_size + 10, heads, max_seq_len, head_dim]\n )", + "completions": [ + "```python\ndef test_setup_cache_max_batch_size_greater_than_max():\n max_batch_size = 10\n max_seq_len = 10\n dim = 10\n heads = 10\n layers = [1, 2, 3]\n block_size = 10\n rope_base = 10\n\n causal_mask, freq_cis = setup_cache(\n max_batch_size, max_seq_len, dim, heads, 
layers, block_size, rope_base\n )\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size == max_batch_size\n assert b.attention.kv_cache.max_seq_len == max_seq_len\n assert b.attention.kv_cache.heads == heads\n assert b.attention.kv_cache.head_dim == dim // heads\n\n 
assert causal_mask.shape == (max_seq_len, max_seq_len)\n assert causal_mask.dtype == torch.bool\n assert causal_mask.all()\n\n assert freq_cis.shape == (block_size, dim // heads)\n assert freq_cis.dtype == torch.float32\n assert freq_cis.all()\n\n for b in layers:\n assert b.attention.kv_cache.max_batch_size" + ], + "line": 26, + "token": 168, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 3, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> 
None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_file_to_be_downloaded", + "reference": "\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n response = Response()\n response.status_code = 200\n image = np.zeros((128, 128, 3), dtype=np.uint8)\n _, encoded_image = cv2.imencode(\".jpg\", image)\n response._content = encoded_image\n requests_get_mock.return_value = response\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=\"https://some/file.jpg\"\n )\n\n # then\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert scaling_factor is None\n assert decoding_result.shape == image.shape\n assert np.allclose(decoding_result, image)\n requests_get_mock.assert_called_once_with(\"https://some/file.jpg\")", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_file_to_be_downloaded(\n requests_get_mock: MagicMock,\n) -> None:\n # given\n reference = \"https://www.example.com/image.jpg\"\n max_height = 100\n max_width = 100\n requests_get_mock.return_value.content = b\"image_bytes\"\n # when\n result = load_image_from_string(\n reference=reference, max_height=max_height, max_width=max_width\n )\n # then\n assert result == (\n \"data:image/jpeg;base64,image_bytes\",\n None,\n )\n```\n" + ], + "line": 47, + "token": 276, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 4, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_random_circuits", + "reference": "\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n placer = placer(connectivity=connectivity)\n layout_circ = Circuit(5)\n initial_layout = placer(layout_circ)\n router = Sabre(connectivity=connectivity, lookahead=look, decay_lookahead=decay)\n circuit = generate_random_circuit(nqubits=5, ngates=n_gates)\n measurement = gates.M(*range(5))\n circuit.add(measurement)\n transpiled_circuit, final_qubit_map = router(circuit, initial_layout)\n assert router.added_swaps >= 0\n assert_connectivity(connectivity, transpiled_circuit)\n assert_placement(transpiled_circuit, final_qubit_map)\n assert n_gates + router.added_swaps + 1 == transpiled_circuit.ngates\n assert_circuit_equivalence(\n original_circuit=circuit,\n 
transpiled_circuit=transpiled_circuit,\n final_map=final_qubit_map,\n initial_map=initial_layout,\n )\n assert transpiled_circuit.queue[-1].register_name == measurement.register_name", + "completions": [ + "```python\ndef test_sabre_random_circuits(n_gates, look, decay, placer, connectivity):\n import random\n import networkx as nx\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from 
pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import Program\n from pyquil.gates import *\n from pyquil.paulis import PauliTerm, PauliSum\n from pyquil.api import QVMConnection\n from pyquil.quil import" + ], + "line": 27, + "token": 250, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 5, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef xx_circuit_step(source, strength, target, embodiment):\n\n permute_source_for_overlap, permute_target_for_overlap = None, None\n\n # apply all possible reflections, shifts to the source\n for source_reflection_name in reflection_options:\n reflected_source_coord, source_reflection, reflection_phase_shift = apply_reflection(\n source_reflection_name, source\n )\n for source_shift_name in shift_options:\n shifted_source_coord, source_shift, shift_phase_shift = apply_shift(\n source_shift_name, reflected_source_coord\n )\n\n # check for overlap, back out permutation\n source_shared, target_shared = None, None\n for i, j in [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2), (2, 0), (2, 1), (2, 2)]:\n\n if (\n abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi)) < EPSILON\n or abs(np.mod(abs(shifted_source_coord[i] - target[j]), np.pi) - np.pi)\n < EPSILON\n ):\n source_shared, target_shared = i, j\n break\n if source_shared is None:\n continue\n\n # pick out the other coordinates\n source_first, source_second = (x for x in [0, 1, 2] if x != source_shared)\n target_first, target_second = (x for x in [0, 1, 2] if x != target_shared)\n\n # check for arccos validity\n r, s, u, v, x, y = decompose_xxyy_into_xxyy_xx(\n float(target[target_first]),\n float(target[target_second]),\n float(shifted_source_coord[source_first]),\n float(shifted_source_coord[source_second]),\n float(strength),\n )\n if any(math.isnan(val) for val in (r, s, u, v, x, y)):\n continue\n\n # OK: this combination of things works.\n # save the permutation which rotates the shared coordinate into ZZ.\n permute_source_for_overlap = canonical_rotation_circuit(source_first, source_second)\n permute_target_for_overlap = canonical_rotation_circuit(target_first, target_second)\n break\n\n if 
permute_source_for_overlap is not None:\n break\n\n if permute_source_for_overlap is None:\n raise QiskitError(\n \"Error during RZX decomposition: Could not find a suitable Weyl \"\n f\"reflection to match {source} to {target} along {strength}.\"\n )\n\n prefix_circuit, affix_circuit = QuantumCircuit(2), QuantumCircuit(2)\n\n # the basic formula we're trying to work with is:\n # target^p_t_f_o =\n # rs * (source^s_reflection * s_shift)^p_s_f_o * uv * operation * xy\n # but we're rearranging it into the form\n # target = affix source prefix\n # and computing just the prefix / affix circuits.\n\n # the outermost prefix layer comes from the (inverse) target permutation.\n prefix_circuit.compose(permute_target_for_overlap.inverse(), inplace=True)\n # the middle prefix layer comes from the local Z rolls.\n prefix_circuit.rz(2 * x, [0])\n prefix_circuit.rz(2 * y, [1])\n prefix_circuit.compose(embodiment, inplace=True)\n prefix_circuit.rz(2 * u, [0])\n prefix_circuit.rz(2 * v, [1])\n # the innermost prefix layer is source_reflection, shifted by source_shift,\n # finally conjugated by p_s_f_o.\n prefix_circuit.compose(permute_source_for_overlap, inplace=True)\n prefix_circuit.compose(source_reflection, inplace=True)\n prefix_circuit.global_phase += -np.log(reflection_phase_shift).imag\n prefix_circuit.global_phase += -np.log(shift_phase_shift).imag\n\n # the affix circuit is constructed in reverse.\n # first (i.e., innermost), we install the other half of the source transformations and p_s_f_o.\n affix_circuit.compose(source_reflection.inverse(), inplace=True)\n affix_circuit.compose(source_shift, inplace=True)\n affix_circuit.compose(permute_source_for_overlap.inverse(), inplace=True)\n # then, the other local rolls in the middle.\n affix_circuit.rz(2 * r, [0])\n affix_circuit.rz(2 * s, [1])\n # finally, the other half of the p_t_f_o conjugation.\n affix_circuit.compose(permute_target_for_overlap, inplace=True)\n\n return {\"prefix_circuit\": prefix_circuit, 
\"affix_circuit\": affix_circuit}\n```\n###test function signature:\n```python\n def test_xx_circuit_step(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_xx_circuit_step", + "reference": " def test_xx_circuit_step(self):\n\n for _ in range(100):\n source_coordinate, interaction, target_coordinate = self._generate_xxyy_test_case()\n\n source_embodiment = qiskit.QuantumCircuit(2)\n source_embodiment.append(UnitaryGate(canonical_matrix(*source_coordinate)), [0, 1])\n interaction_embodiment = qiskit.QuantumCircuit(2)\n interaction_embodiment.append(UnitaryGate(canonical_matrix(*interaction)), [0, 1])\n\n prefix_circuit, affix_circuit = itemgetter(\"prefix_circuit\", \"affix_circuit\")(\n xx_circuit_step(\n source_coordinate, interaction[0], target_coordinate, interaction_embodiment\n )\n )\n\n target_embodiment = QuantumCircuit(2)\n target_embodiment.compose(prefix_circuit, inplace=True)\n target_embodiment.compose(source_embodiment, inplace=True)\n target_embodiment.compose(affix_circuit, inplace=True)\n self.assertTrue(\n np.all(\n np.abs(\n qiskit.quantum_info.operators.Operator(target_embodiment).data\n - canonical_matrix(*target_coordinate)\n )\n < EPSILON\n )\n )", + "completions": [ + "```python\n def test_xx_circuit_step(self):\n source = np.array([0.1, 0.2, 0.3])\n strength = 0.4\n target = np.array([0.5, 0.6, 0.7])\n embodiment = QuantumCircuit(2)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.cx(0, 1)\n embodiment.barrier()\n embodiment.rz(0.1, 0)\n embodiment.rz(0.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.3, 0)\n embodiment.rz(0.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.5, 0)\n embodiment.rz(0.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.7, 0)\n embodiment.rz(0.8, 
1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(0.9, 0)\n embodiment.rz(1.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.1, 0)\n embodiment.rz(1.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.3, 0)\n embodiment.rz(1.4, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.5, 0)\n embodiment.rz(1.6, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.7, 0)\n embodiment.rz(1.8, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(1.9, 0)\n embodiment.rz(2.0, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment.h(1)\n embodiment.barrier()\n embodiment.rz(2.1, 0)\n embodiment.rz(2.2, 1)\n embodiment.barrier()\n embodiment.cx(0, 1)\n embodiment.h(0)\n embodiment" + ], + "line": 90, + "token": 807, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 6, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_einsum(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_einsum", + "reference": "\n def test_einsum(self):\n x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n self.assertAllClose(\n knp.einsum(\"ijk,lkj->il\", x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"ijk,ikj->i\", x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.einsum(\"i...,j...k->...ijk\", x, y),\n np.einsum(\"i..., j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n self.assertAllClose(\n knp.Einsum(\"ijk,lkj->il\")(x, y),\n np.einsum(\"ijk,lkj->il\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"ijk,ikj->i\")(x, y),\n np.einsum(\"ijk,ikj->i\", x, y),\n )\n self.assertAllClose(\n knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n np.einsum(\"i...,j...k->...ijk\", x, y),\n )\n self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))", + "completions": [ + "```python\n\n def test_arange(self):\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32),\n torch.arange(0, 10, 1, 
dtype=torch.float32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64),\n torch.arange(0, 10, 1, dtype=torch.float64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32),\n torch.arange(0, 10, 1, dtype=torch.int32),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64),\n torch.arange(0, 10, 1, dtype=torch.int64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool),\n torch.arange(0, 10, 1, dtype=torch.bool),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64),\n torch.arange(0, 10, 1, dtype=torch.complex64),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128),\n torch.arange(0, 10, 1, dtype=torch.complex128),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.float64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int32, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.int64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.bool, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex64, device=\"cuda\"),\n )\n self.assertEqual(\n arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n torch.arange(0, 10, 1, dtype=torch.complex128, device=\"cuda\"),\n )\n```\n" + ], + "line": 25, + "token": 195, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 7, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_center_padding(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_center_padding", + "reference": "\n def test_spectrogram_center_padding(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"reflect\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.1287945 , 0.12792738, 0.08311573, 0.03155122, 0.02470202,\n 0.00727857, 0.00910694, 0.00686163, 0.01238981, 0.01473668,\n 0.00336144, 0.00370314, 0.00600871, 0.01120164, 0.01942998,\n 0.03132008, 0.0232842 , 0.01124642, 0.02754783, 0.02423725,\n 0.00147893, 0.00038027, 0.00112299, 0.00596233, 0.00571529,\n 0.02084235, 0.0231855 , 0.00810006, 0.01837943, 0.00651339,\n 0.00093931, 0.00067426, 0.01058399, 0.01270507, 0.00151734,\n 0.00331913, 0.00302416, 0.01081792, 0.00754549, 0.00148963,\n 0.00111943, 0.00152573, 0.00608017, 0.01749986, 0.01205949,\n 0.0143082 , 0.01910573, 0.00413786, 0.03916619, 0.09873404,\n 0.08302026, 0.02673891, 0.00401255, 0.01397392, 0.00751862,\n 0.01024884, 0.01544606, 0.00638907, 0.00623633, 0.0085103 ,\n 0.00217659, 0.00276204, 0.00260835, 0.00299299,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=True,\n pad_mode=\"constant\",\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.06558744, 0.06889656, 0.06263352, 0.04264418, 0.03404115,\n 0.03244197, 0.02279134, 0.01646339, 0.01452216, 0.00826055,\n 0.00062093, 0.0031821 , 0.00419456, 
0.00689327, 0.01106367,\n 0.01712119, 0.01721762, 0.00977533, 0.01606626, 0.02275621,\n 0.01727687, 0.00992739, 0.01217688, 0.01049927, 0.01022947,\n 0.01302475, 0.01166873, 0.01081812, 0.01057327, 0.00767912,\n 0.00429567, 0.00089625, 0.00654583, 0.00912084, 0.00700984,\n 0.00225026, 0.00290545, 0.00667712, 0.00730663, 0.00410813,\n 0.00073102, 0.00219296, 0.00527618, 0.00996585, 0.01123781,\n 0.00872816, 0.01165121, 0.02047945, 0.03681747, 0.0514379 ,\n 0.05137928, 0.03960042, 0.02821562, 0.01813349, 0.01201322,\n 0.01260964, 0.00900654, 0.00207905, 0.00456714, 0.00850599,\n 0.00788239, 0.00664407, 0.00824227, 0.00628301,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=128,\n center=False,\n )\n self.assertEqual(spec.shape, (257, 728))\n\n # fmt: off\n expected = np.array([\n 0.00250445, 0.02161521, 0.06232229, 0.04339567, 0.00937727,\n 0.01080616, 0.00248685, 0.0095264 , 0.00727476, 0.0079152 ,\n 0.00839946, 0.00254932, 0.00716622, 0.005559 , 0.00272623,\n 0.00581774, 0.01896395, 0.01829788, 0.01020514, 0.01632692,\n 0.00870888, 0.02065827, 0.0136022 , 0.0132382 , 0.011827 ,\n 0.00194505, 0.0189979 , 0.026874 , 0.02194014, 0.01923883,\n 0.01621437, 0.00661967, 0.00289517, 0.00470257, 0.00957801,\n 0.00191455, 0.00431664, 0.00544359, 0.01126213, 0.00785778,\n 0.00423469, 0.01322504, 0.02226548, 0.02318576, 0.03428908,\n 0.03648811, 0.0202938 , 0.011902 , 0.03226198, 0.06347476,\n 0.01306318, 0.05308729, 0.05474771, 0.03127991, 0.00998512,\n 0.01449977, 0.01272741, 0.00868176, 0.00850386, 0.00313876,\n 0.00811857, 0.00538216, 0.00685749, 0.00535275,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 0], expected))", + "completions": [ + "```python\n def test_spectrogram_center_padding(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 4\n hop_length = 2\n 
fft_length = 16\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n np.testing.assert_array_equal(\n spectrogram,\n np.array(\n [\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],\n [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0," + ], + "line": 95, + "token": 804, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 8, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_kword_ref_line", + "reference": "\ndef test_nanoplot_out_bars_with_kword_ref_line():\n\n out_bars_with_kword_ref_line = _generate_nanoplot(**CASES[12])\n\n assert _is_nanoplot_output(out_bars_with_kword_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_kword_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_kword_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = \"max\"\n y_ref_area = [1, 2]\n x_vals = [1, 2, 3, 4, 5]\n 
expand_x = \"min\"\n expand_y = \"min\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n 
data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_stroke_color=data_bar_negative_stroke_" + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 9, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef arange(start, stop=None, step=1, dtype=None):\n if dtype is None:\n dtypes_to_resolve = [\n getattr(start, \"dtype\", type(start)),\n getattr(step, \"dtype\", type(step)),\n ]\n if stop is not None:\n dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n dtype = dtypes.result_type(*dtypes_to_resolve)\n dtype = to_torch_dtype(dtype)\n if stop is None:\n return torch.arange(end=start, dtype=dtype, device=get_device())\n return torch.arange(\n start, stop, step=step, dtype=dtype, device=get_device()\n )\n```\n###test function signature:\n```python\n\n def test_take_along_axis(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_take_along_axis", + "reference": "\n def test_take_along_axis(self):\n x = np.arange(24).reshape([1, 2, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=1),\n np.take_along_axis(x, indices, axis=1),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=1)(x, indices),\n np.take_along_axis(x, indices, axis=1),\n )\n\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )\n\n # Test with axis=None\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.array([1, 2, 3], 
dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=None),\n np.take_along_axis(x, indices, axis=None),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=None)(x, indices),\n np.take_along_axis(x, indices, axis=None),\n )\n\n # Test with negative indices\n x = np.arange(12).reshape([1, 1, 3, 4])\n indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n self.assertAllClose(\n knp.take_along_axis(x, indices, axis=2),\n np.take_along_axis(x, indices, axis=2),\n )\n self.assertAllClose(\n knp.TakeAlongAxis(axis=2)(x, indices),\n np.take_along_axis(x, indices, axis=2),\n )", + "completions": [ + "```python\n\n def test_take_along_axis(self):\n # test arange\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=1, dtype=torch.float32),\n torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=2, dtype=torch.float32),\n torch.tensor([0, 2, 4, 6, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=3, dtype=torch.float32),\n torch.tensor([0, 3, 6, 9], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=4, dtype=torch.float32),\n torch.tensor([0, 4, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=5, dtype=torch.float32),\n torch.tensor([0, 5], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=6, dtype=torch.float32),\n torch.tensor([0, 6], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=7, dtype=torch.float32),\n torch.tensor([0, 7], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=8, dtype=torch.float32),\n torch.tensor([0, 8], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=9, dtype=torch.float32),\n torch.tensor([0, 9], 
dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=10, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=11, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=12, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=13, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32),\n )\n )\n self.assertTrue(\n torch.equal(\n arange(start=0, stop=10, step=14, dtype=torch.float32),\n torch.tensor([0], dtype=torch.float32" + ], + "line": 25, + "token": 195, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 10, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_correlation(\n response_level_scores: dict[str, list[dict[str, dict[str, Any]]]],\n name: str = '',\n) -> dict[str, dict[str, dict[str, float]]]:\n factscore_metrics = set(list_metrics(response_level_scores[FACTSCORE]))\n pred_metrics = set(list_metrics(response_level_scores[EVAL_METHOD]))\n result = {}\n\n for metric in factscore_metrics.intersection(pred_metrics):\n factscore_scores = [\n find_metric(response_level_scores[FACTSCORE][i], metric)\n for i in range(len(response_level_scores[FACTSCORE]))\n ]\n pred_scores = [\n find_metric(response_level_scores[EVAL_METHOD][i], metric)\n for i in range(len(response_level_scores[EVAL_METHOD]))\n ]\n\n # Remove failed runs\n for i in range(len(factscore_scores) - 1, -1, -1):\n if factscore_scores[i] == -1 or pred_scores[i] == -1:\n factscore_scores.pop(i)\n pred_scores.pop(i)\n\n # Cannot calculate correlation with only one data point\n if 
len(factscore_scores) <= 1 or len(pred_scores) <= 1:\n pearson_result, spearman_result = None, None\n else:\n pearson_result = stats.pearsonr(factscore_scores, pred_scores)\n spearman_result = stats.spearmanr(factscore_scores, pred_scores)\n\n scatter_plot(\n factscore_scores,\n pred_scores,\n f'{name}-{metric}',\n FACTSCORE,\n EVAL_METHOD,\n )\n result[metric] = {\n _PEARSON: CorrelationResult(pearson_result).result,\n _SPEARMAN: CorrelationResult(spearman_result).result,\n }\n\n return result\n```\n###test function signature:\n```python\n\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_correlation_one_data_point", + "reference": "\n def test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n ) -> None:\n response_level_scores = {\n correlation_vs_factscore.FACTSCORE: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'N': 1}},\n ],\n correlation_vs_factscore.EVAL_METHOD: [\n {correlation_vs_factscore.RAW_STR: {'S': 0, 'M': 1}},\n ],\n }\n actual_output = correlation_vs_factscore.compute_correlation(\n response_level_scores=response_level_scores\n )\n mock_pearsonr.assert_not_called()\n mock_spearmanr.assert_not_called()\n mock_scatter_plot.assert_called()\n self.assertIsInstance(actual_output, dict)\n self.assertLen(actual_output, 1)\n self.assertIn('S', actual_output)\n self.assertIsInstance(actual_output['S'], dict)\n self.assertLen(actual_output['S'], 2)\n self.assertIn('Pearson', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Pearson'], dict)\n self.assertIn('correlation', actual_output['S']['Pearson'])\n self.assertTrue(math.isnan(actual_output['S']['Pearson']['correlation']))\n self.assertIn('p_value', actual_output['S']['Pearson'])\n 
self.assertTrue(math.isnan(actual_output['S']['Pearson']['p_value']))\n self.assertIn('Spearman', actual_output['S'])\n self.assertIsInstance(actual_output['S']['Spearman'], dict)\n self.assertIn('correlation', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['correlation']))\n self.assertIn('p_value', actual_output['S']['Spearman'])\n self.assertTrue(math.isnan(actual_output['S']['Spearman']['p_value']))", + "completions": [ + "```python\ndef test_compute_correlation_one_data_point(\n self,\n mock_spearmanr: mock.Mock,\n mock_pearsonr: mock.Mock,\n mock_scatter_plot: mock.Mock,\n) -> None:\n response_level_scores = {\n FACTSCORE: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n EVAL_METHOD: [\n {\n 'metric': 'metric1',\n 'value': 0.5,\n },\n {\n 'metric': 'metric2',\n 'value': 0.6,\n },\n ],\n }\n result = compute_correlation(response_level_scores)\n self.assertEqual(result, {})\n mock_spearmanr.assert_not_called()\n mock_pearsonr.assert_not_called()\n mock_scatter_plot.assert_not_called()\n```\n" + ], + "line": 54, + "token": 384, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 11, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_data_lines():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_data_lines", + "reference": "\ndef test_nanoplot_out_data_lines():\n\n out_data_lines = _generate_nanoplot(**CASES[0])\n\n assert _is_nanoplot_output(out_data_lines)\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_data_lines,\n tag=\"g\",\n attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )", + "completions": [ + "```python\n\ndef 
test_nanoplot_out_data_lines():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n expected = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n 
data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color,\n data_bar_negative_" + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 12, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_daily_temp_mean", + "reference": "\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data = il_electricity_cdd_hdd_daily[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_daily[\"temperature_data\"]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert df.shape == (810, 3)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_mean\",\n ]\n\n assert round(df.temperature_mean.mean()) == 55.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_daily_temp_mean(il_electricity_cdd_hdd_daily):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"D\", tz=\"UTC\"\n )\n temperature_data = pd.DataFrame(\n {\"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]},\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", 
tz=\"UTC\"\n ),\n )\n df = compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=[],\n cooling_balance_points=[],\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n )\n expected = pd.DataFrame(\n {\n \"temperature_mean\": [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],\n \"n_hours_kept\": [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],\n \"n_hours_dropped\": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n },\n index=pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-12\", freq=\"D\", tz=\"UTC\"\n ),\n )\n pd.testing.assert_frame_equal(df, expected)\n```\n" + ], + "line": 166, + "token": 983, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 13, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef power_color(\n frequency,\n power,\n power_err=None,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n):\n freq_edges = np.asarray(freq_edges)\n if len(freq_edges) != 5:\n raise ValueError(\"freq_edges must have 5 elements\")\n\n frequency = np.asarray(frequency)\n power = np.asarray(power)\n\n if df is None:\n df = np.median(np.diff(frequency))\n input_frequency_low_edges = frequency - df / 2\n input_frequency_high_edges = frequency + df / 2\n\n if freq_edges.min() < input_frequency_low_edges[0]:\n raise ValueError(\"The minimum frequency is larger than the first frequency edge\")\n if freq_edges.max() > input_frequency_high_edges[-1]:\n raise ValueError(\"The maximum frequency is lower than the last frequency edge\")\n\n if power_err is None:\n 
power_err = power / np.sqrt(m)\n else:\n power_err = np.asarray(power_err)\n\n if freqs_to_exclude is not None:\n if len(np.shape(freqs_to_exclude)) == 1:\n freqs_to_exclude = [freqs_to_exclude]\n\n if (\n not isinstance(freqs_to_exclude, Iterable)\n or len(np.shape(freqs_to_exclude)) != 2\n or np.shape(freqs_to_exclude)[1] != 2\n ):\n raise ValueError(\"freqs_to_exclude must be of format [[f0, f1], [f2, f3], ...]\")\n for f0, f1 in freqs_to_exclude:\n frequency_mask = (input_frequency_low_edges > f0) & (input_frequency_high_edges < f1)\n idx0, idx1 = np.searchsorted(frequency, [f0, f1])\n power[frequency_mask] = np.mean([power[idx0], power[idx1]])\n\n var00, var00_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[:2],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var01, var01_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[2:4],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var10, var10_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[1:3],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n var11, var11_err = integrate_power_in_frequency_range(\n frequency,\n power,\n freq_edges[3:5],\n power_err=power_err,\n df=df,\n m=m,\n poisson_power=poisson_power,\n )\n pc0 = var00 / var01\n pc1 = var10 / var11\n pc0_err = pc0 * (var00_err / var00 + var01_err / var01)\n pc1_err = pc1 * (var10_err / var10 + var11_err / var11)\n if return_log:\n pc0_err = 1 / pc0 * pc0_err\n pc1_err = 1 / pc1 * pc1_err\n pc0 = np.log10(pc0)\n pc1 = np.log10(pc1)\n return pc0, pc0_err, pc1, pc1_err\n```\n###test function signature:\n```python\n\n def test_with_power_err(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_power_err", + "reference": "\n def test_with_power_err(self):\n pc0, pc0_err, pc1, pc1_err = power_color(\n self.freq,\n self.power,\n power_err=self.power / 2,\n )\n pc0e, 
pc0e_err, pc1e, pc1e_err = power_color(\n self.freq,\n self.power,\n power_err=self.power,\n )\n assert np.isclose(pc0, 1, atol=0.001)\n assert np.isclose(pc1, 1, atol=0.001)\n assert np.isclose(pc0e, 1, atol=0.001)\n assert np.isclose(pc1e, 1, atol=0.001)\n assert np.isclose(pc0e_err / pc0_err, 2, atol=0.001)\n assert np.isclose(pc1e_err / pc1_err, 2, atol=0.001)", + "completions": [ + "```python\n def test_with_power_err(self):\n frequency = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n power_err = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])\n pc0, pc0_err, pc1, pc1_err = power_color(\n frequency,\n power,\n power_err=power_err,\n freq_edges=[1 / 256, 1 / 32, 0.25, 2.0, 16.0],\n df=None,\n m=1,\n freqs_to_exclude=None,\n poisson_power=0,\n return_log=False,\n )\n self.assertEqual(pc0, 0.0)\n self.assertEqual(pc0_err, 0.0)\n self.assertEqual(pc1, 0.0)\n self.assertEqual(pc1_err, 0.0)\n```\n" + ], + "line": 97, + "token": 560, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 14, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def apply_top_k(\n self, logits: numpy.ndarray, filter_value=_MIN_FLOAT\n ) -> numpy.ndarray:\n logits_shape = logits.shape\n logits = logits.reshape(logits.shape[-1])\n top_k_indices = numpy.argpartition(logits, -self.top_k)[-self.top_k :]\n logits[~numpy.isin(numpy.arange(len(logits)), top_k_indices)] = filter_value\n\n return logits.reshape(logits_shape)\n```\n###test function signature:\n```python\n\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_apply_topk", + "reference": "\n def test_apply_topk(\n self,\n ):\n # logits for opt usually have shape (1,1,51200)\n logits = numpy.linspace(0, 1, 11).reshape((1, 1, 11))\n\n token_generator = TokenGenerator(\n logits_shape=logits[-1].shape[-1],\n top_k=3,\n )\n\n filter_value = -float(\"Inf\")\n new_logits = token_generator.apply_top_k(\n logits.copy(), filter_value=filter_value\n )\n\n for _ in range(token_generator.top_k):\n curr_max, idx = numpy.max(new_logits), numpy.argmax(new_logits)\n assert curr_max > filter_value\n new_logits = numpy.delete(new_logits, idx)\n\n assert numpy.all(new_logits == filter_value)", + "completions": [ + "```python\n def test_apply_topk(\n self,\n ):\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n 
[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]])\n self.assertEqual(\n self.apply_top_k(logits, filter_value=_MIN_FLOAT).tolist(),\n [[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]],\n )\n\n logits = numpy.array([[1, 2, 3, 4, 5, 6, 7, 8, 9," + ], + "line": 21, + "token": 152, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 15, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(\n self,\n sources: List[Union[str, Path, ByteStream]],\n meta: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,\n ):\n parser = MarkdownIt(renderer_cls=RendererPlain)\n if self.table_to_single_line:\n parser.enable(\"table\")\n\n documents = []\n meta_list = normalize_metadata(meta=meta, sources_count=len(sources))\n\n for source, metadata in tqdm(\n zip(sources, meta_list),\n total=len(sources),\n desc=\"Converting markdown files to Documents\",\n disable=not self.progress_bar,\n ):\n try:\n bytestream = get_bytestream_from_source(source)\n except Exception as e:\n logger.warning(\"Could not read {source}. Skipping it. 
Error: {error}\", source=source, error=e)\n continue\n try:\n file_content = bytestream.data.decode(\"utf-8\")\n text = parser.render(file_content)\n except Exception as conversion_e:\n logger.warning(\n \"Failed to extract text from {source}. Skipping it. Error: {error}\",\n source=source,\n error=conversion_e,\n )\n continue\n\n merged_metadata = {**bytestream.meta, **metadata}\n document = Document(content=text, meta=merged_metadata)\n documents.append(document)\n\n return {\"documents\": documents}\n```\n###test function signature:\n```python\n\n def test_run_calls_normalize_metadata(self, test_files_path):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_calls_normalize_metadata", + "reference": "\n def test_run_calls_normalize_metadata(self, test_files_path):\n bytestream = ByteStream(data=b\"test\", meta={\"author\": \"test_author\", \"language\": \"en\"})\n\n converter = MarkdownToDocument()\n\n with patch(\"haystack.components.converters.markdown.normalize_metadata\") as normalize_metadata, patch(\n \"haystack.components.converters.markdown.MarkdownIt\"\n ):\n converter.run(sources=[bytestream, test_files_path / \"markdown\" / \"sample.md\"], meta={\"language\": \"it\"})\n\n # check that the metadata normalizer is called properly\n normalize_metadata.assert_called_with(meta={\"language\": \"it\"}, sources_count=2)", + "completions": [ + "```python\n def test_run_calls_normalize_metadata(self, test_files_path):\n # Arrange\n sources = [test_files_path / \"test_file.md\"]\n meta = {\"key\": \"value\"}\n\n # Act\n run(sources, meta)\n\n # Assert\n normalize_metadata.assert_called_once_with(meta=meta, sources_count=1)\n```\n" + ], + "line": 45, + "token": 311, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 16, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_config(config: object) -> None:\n if hasattr(config, \"kind\"):\n Config.running_config.classes[config.kind] = config # type: ignore\n Config.running_config.types[config.kind] = {}\n for field in fields(config): # type: ignore\n if hasattr(field, \"type\"):\n Config.running_config.types[config.kind][field.name] = optional_origin(field.type)\n else:\n raise RuntimeError(\"Config must have a 'kind' attribute\")\n```\n###test function signature:\n```python\n\ndef test_config():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_config", + "reference": "\ndef test_config():\n arg_parser = get_arg_parser()\n core_add_args(arg_parser)\n arg_parser.parse_known_args()\n cfg = Config(\"test\")\n cfg.add_config(ConfigTest)\n cfg.init_default_config()\n assert Config.dict() == {\n \"configtest\": {\n \"testvar1\": \"testing123\",\n \"testvar2\": 12345,\n \"testvar3\": {\n \"mydict\": {\"foo\": \"bar\", \"abc\": {\"def\": \"ghi\"}},\n \"myint\": 0,\n \"mystr\": \"Hello\",\n },\n }\n }\n cfg2 = Config(\"test2\")\n assert cfg.configtest.testvar1 == cfg2.configtest.testvar1\n assert cfg.configtest.testvar1 == Config.configtest.testvar1\n assert cfg.configtest.testvar2 == cfg2.configtest.testvar2\n assert cfg.configtest.testvar3 == cfg2.configtest.testvar3\n Config.configtest.testvar2 += 1\n assert cfg.configtest.testvar2 == 12346\n with pytest.raises(ConfigNotFoundError):\n Config.does_not_exist.foo = \"bar\"\n with pytest.raises(ConfigNotFoundError):\n cfg.does_not_exist.foo = \"bar\"", + "completions": [ + "```python\ndef test_config():\n from dataclasses import dataclass, field\n from typing import Optional\n\n @dataclass\n class Config:\n kind: str\n a: Optional[int] = field(default=None)\n b: Optional[int] = field(default=None)\n\n 
add_config(Config)\n\n assert Config.running_config.classes[\"Config\"] == Config\n assert Config.running_config.types[\"Config\"] == {\"a\": int, \"b\": int}\n```\n" + ], + "line": 19, + "token": 166, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 17, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int 
= 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_with_num_ref_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_with_num_ref_line", + "reference": "\ndef test_nanoplot_out_with_num_ref_line():\n\n out_with_num_ref_line = _generate_nanoplot(**CASES[1])\n\n assert _is_nanoplot_output(out_with_num_ref_line)\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"pattern\",\n attrs=[\n (\"width\", \"8\"),\n (\"height\", \"8\"),\n (\"patternUnits\", \"userSpaceOnUse\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"path\",\n attrs=[\n (\"class\", \"area-closed\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill-opacity\", \"0.7\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"circle\",\n attrs=[\n (\"cx\", \"50.0\"),\n (\"cy\", \"115.0\"),\n (\"r\", \"10\"),\n (\"stroke\", \"#FFFFFF\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#FF0000\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n attrs=[\n (\"class\", \"vert-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"g\",\n 
attrs=[\n (\"class\", \"y-axis-line\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_with_num_ref_line,\n tag=\"line\",\n attrs=[\n (\"class\", \"ref-line\"),\n (\"x1\", \"50.0\"),\n (\"x2\", \"550\"),\n (\"stroke\", \"#75A8B0\"),\n (\"stroke-width\", \"1\"),\n (\"stroke-dasharray\", \"4 3\"),\n (\"stroke-linecap\", \"round\"),\n (\"vector-effect\", \"non-scaling-stroke\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_with_num_ref_line():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_line = 2\n y_ref_area = [1, 3]\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 5]\n plot_type = \"line\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n 
plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_color," + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 18, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def begin_typing(self, x=None):\n self.capturekeystrokes = True\n # Disable keypress shortcuts, which may otherwise cause the figure to\n # be saved, closed, etc., until the user stops typing. 
The way to\n # achieve this depends on whether toolmanager is in use.\n stack = ExitStack() # Register cleanup actions when user stops typing.\n self._on_stop_typing = stack.close\n toolmanager = getattr(\n self.ax.figure.canvas.manager, \"toolmanager\", None)\n if toolmanager is not None:\n # If using toolmanager, lock keypresses, and plan to release the\n # lock when typing stops.\n toolmanager.keypresslock(self)\n stack.callback(toolmanager.keypresslock.release, self)\n else:\n # If not using toolmanager, disable all keypress-related rcParams.\n # Avoid spurious warnings if keymaps are getting deprecated.\n with _api.suppress_matplotlib_deprecation_warning():\n stack.enter_context(mpl.rc_context(\n {k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n```\n###test function signature:\n```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_TextBox", + "reference": "\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n plt.rcParams._set(\"toolbar\", toolbar)\n\n submit_event = mock.Mock(spec=noop, return_value=None)\n text_change_event = mock.Mock(spec=noop, return_value=None)\n tool = widgets.TextBox(ax, '')\n tool.on_submit(submit_event)\n tool.on_text_change(text_change_event)\n\n assert tool.text == ''\n\n do_event(tool, '_click')\n\n tool.set_val('x**2')\n\n assert tool.text == 'x**2'\n assert text_change_event.call_count == 1\n\n tool.begin_typing()\n tool.stop_typing()\n\n assert submit_event.call_count == 2\n\n do_event(tool, '_click')\n do_event(tool, '_keypress', key='+')\n do_event(tool, '_keypress', key='5')\n\n assert text_change_event.call_count == 3", + "completions": [ + "```python\n\ndef test_TextBox(ax, toolbar):\n # Avoid \"toolmanager is provisional\" warning.\n with ExitStack() as stack:\n stack.enter_context(mpl.rc_context({k: [] for k in mpl.rcParams if k.startswith(\"keymap.\")}))\n 
ax.begin_typing()\n assert ax.capturekeystrokes is True\n assert toolbar.keypresslock.is_locked(ax)\n```\n" + ], + "line": 31, + "token": 258, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 19, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def read_events(\n db: PrefectDBInterface,\n session: AsyncSession,\n events_filter: EventFilter,\n limit: \"int | None\" = None,\n offset: \"int | None\" = None,\n) -> Sequence[ORMEvent]:\n # Always order by occurred timestamp, with placeholder for order direction\n order = sa.desc if events_filter.order == EventOrder.DESC else sa.asc\n\n # Check if distinct fields are provided\n if distinct_fields := build_distinct_queries(events_filter):\n # Define window function\n window_function = (\n sa.func.row_number()\n .over(partition_by=distinct_fields, order_by=order(db.Event.occurred))\n .label(\"row_number\")\n )\n # Create a subquery with the window function\n subquery = (\n sa.select(db.Event, window_function)\n .where(\n sa.and_(\n *events_filter.build_where_clauses(db)\n ) # Ensure the same filters are applied here\n )\n .subquery()\n )\n\n # Alias the subquery for easier column references\n aliased_table = aliased(db.Event, subquery)\n\n # Create the final query from the subquery, filtering to get only rows with row_number = 1\n select_events_query = sa.select(aliased_table).where(subquery.c.row_number == 1)\n\n # Order by the occurred timestamp\n select_events_query = select_events_query.order_by(order(subquery.c.occurred))\n\n else:\n # If no distinct fields are provided, create a query for all events\n select_events_query = sa.select(db.Event).where(\n sa.and_(*events_filter.build_where_clauses(db))\n )\n # Order by the occurred timestamp\n select_events_query = 
select_events_query.order_by(order(db.Event.occurred))\n\n if limit is not None:\n limit = max(0, min(limit, events_filter.logical_limit))\n select_events_query = select_events_query.limit(limit=limit)\n if offset is not None:\n select_events_query = select_events_query.offset(offset=offset)\n\n logger.debug(\"Running PostgreSQL query: %s\", select_events_query)\n\n select_events_query_result = await session.execute(select_events_query)\n return select_events_query_result.scalars().unique().all()\n```\n###test function signature:\n```python\n\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_read_events", + "reference": "\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_1.id\n\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=3),\n until=pendulum.now(\"UTC\").subtract(days=1),\n ),\n ),\n )\n assert len(events) == 1\n assert events[0].id == event_2.id\n\n # Read criteria should apply AND logic\n # Nothing should be returned for this resource\n async with session as session:\n events = await read_events(\n session=session,\n events_filter=EventFilter(\n resource=EventResourceFilter(id=[\"prefect.garbage.foo\"]),\n occurred=EventOccurredFilter(\n since=pendulum.now(\"UTC\").subtract(days=1)\n ),\n ),\n )\n assert len(events) == 0", + "completions": [ + "```python\n async def test_read_events(self, session: AsyncSession, event_1, event_2):\n # Arrange\n events_filter = EventFilter(\n occurred_after=datetime.datetime(2022, 1, 1),\n 
occurred_before=datetime.datetime(2022, 1, 2),\n order=EventOrder.DESC,\n )\n # Act\n events = await read_events(\n db=self.db,\n session=session,\n events_filter=events_filter,\n )\n # Assert\n self.assertEqual(len(events), 2)\n self.assertEqual(events[0], event_2)\n self.assertEqual(events[1], event_1)\n```\n" + ], + "line": 58, + "token": 412, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 20, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n def test_evaluate_diff_dim(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_evaluate_diff_dim", + "reference": " def test_evaluate_diff_dim(self):\n x1 = np.arange(3, 10, 2)\n kde = mlab.GaussianKDE(x1)\n x2 = np.arange(3, 12, 2)\n y_expected = [\n 0.08797252, 0.11774109, 0.11774109, 0.08797252, 0.0370153\n ]\n y = kde.evaluate(x2)\n np.testing.assert_array_almost_equal(y, y_expected, 7)", + "completions": [ + 
"```python\n def test_evaluate_diff_dim(self):\n \"\"\"\n Test evaluate with different dimensions\n \"\"\"\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import StandardScaler\n\n # Generate a random dataset\n X, _ = make_blobs(n_samples=100, n_features=10, centers=5,\n random_state=42)\n\n # Standardize the data\n scaler = StandardScaler()\n X = scaler.fit_transform(X)\n\n # Create a GaussianMixture object\n gmm = GaussianMixture(n_components=5, covariance_type='full')\n gmm.fit(X)\n\n # Generate a random set of points\n n_points = 100\n points = np.random.rand(n_points, 10)\n\n # Evaluate the points\n result = gmm.evaluate(points)\n\n # Check that the result is correct\n self.assertEqual(result.shape, (n_points,))\n self.assertTrue(np.all(result >= 0))\n self.assertTrue(np.all(result <= 1))\n```\n" + ], + "line": 33, + "token": 273, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 21, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _differentiate(func, x, *, args=(), atol=None, rtol=None, maxiter=10,\n order=8, initial_step=0.5, step_factor=2.0,\n step_direction=0, preserve_shape=False, callback=None):\n # TODO (followup):\n # - investigate behavior at saddle points\n # - array initial_step / step_factor?\n # - multivariate functions?\n\n res = _differentiate_iv(func, x, args, atol, rtol, maxiter, order, initial_step,\n step_factor, step_direction, preserve_shape, callback)\n (func, x, args, atol, rtol, maxiter, order,\n h0, fac, hdir, preserve_shape, callback) = res\n\n # Initialization\n # Since f(x) (no step) is not needed for central differences, it may be\n # possible to eliminate this function evaluation. 
However, it's useful for\n # input validation and standardization, and everything else is designed to\n # reduce function calls, so let's keep it simple.\n temp = eim._initialize(func, (x,), args, preserve_shape=preserve_shape)\n func, xs, fs, args, shape, dtype = temp\n x, f = xs[0], fs[0]\n df = np.full_like(f, np.nan)\n # Ideally we'd broadcast the shape of `hdir` in `_elementwise_algo_init`, but\n # it's simpler to do it here than to generalize `_elementwise_algo_init` further.\n # `hdir` and `x` are already broadcasted in `_differentiate_iv`, so we know\n # that `hdir` can be broadcasted to the final shape.\n hdir = np.broadcast_to(hdir, shape).flatten()\n\n status = np.full_like(x, eim._EINPROGRESS, dtype=int) # in progress\n nit, nfev = 0, 1 # one function evaluations performed above\n # Boolean indices of left, central, right, and (all) one-sided steps\n il = hdir < 0\n ic = hdir == 0\n ir = hdir > 0\n io = il | ir\n\n # Most of these attributes are reasonably obvious, but:\n # - `fs` holds all the function values of all active `x`. The zeroth\n # axis corresponds with active points `x`, the first axis corresponds\n # with the different steps (in the order described in\n # `_differentiate_weights`).\n # - `terms` (which could probably use a better name) is half the `order`,\n # which is always even.\n work = _RichResult(x=x, df=df, fs=f[:, np.newaxis], error=np.nan, h=h0,\n df_last=np.nan, error_last=np.nan, h0=h0, fac=fac,\n atol=atol, rtol=rtol, nit=nit, nfev=nfev,\n status=status, dtype=dtype, terms=(order+1)//2,\n hdir=hdir, il=il, ic=ic, ir=ir, io=io)\n # This is the correspondence between terms in the `work` object and the\n # final result. In this case, the mapping is trivial. 
Note that `success`\n # is prepended automatically.\n res_work_pairs = [('status', 'status'), ('df', 'df'), ('error', 'error'),\n ('nit', 'nit'), ('nfev', 'nfev'), ('x', 'x')]\n\n def pre_func_eval(work):\n \"\"\"Determine the abscissae at which the function needs to be evaluated.\n\n See `_differentiate_weights` for a description of the stencil (pattern\n of the abscissae).\n\n In the first iteration, there is only one stored function value in\n `work.fs`, `f(x)`, so we need to evaluate at `order` new points. In\n subsequent iterations, we evaluate at two new points. Note that\n `work.x` is always flattened into a 1D array after broadcasting with\n all `args`, so we add a new axis at the end and evaluate all point\n in one call to the function.\n\n For improvement:\n - Consider measuring the step size actually taken, since `(x + h) - x`\n is not identically equal to `h` with floating point arithmetic.\n - Adjust the step size automatically if `x` is too big to resolve the\n step.\n - We could probably save some work if there are no central difference\n steps or no one-sided steps.\n \"\"\"\n n = work.terms # half the order\n h = work.h # step size\n c = work.fac # step reduction factor\n d = c**0.5 # square root of step reduction factor (one-sided stencil)\n # Note - no need to be careful about dtypes until we allocate `x_eval`\n\n if work.nit == 0:\n hc = h / c**np.arange(n)\n hc = np.concatenate((-hc[::-1], hc))\n else:\n hc = np.asarray([-h, h]) / c**(n-1)\n\n if work.nit == 0:\n hr = h / d**np.arange(2*n)\n else:\n hr = np.asarray([h, h/d]) / c**(n-1)\n\n n_new = 2*n if work.nit == 0 else 2 # number of new abscissae\n x_eval = np.zeros((len(work.hdir), n_new), dtype=work.dtype)\n il, ic, ir = work.il, work.ic, work.ir\n x_eval[ir] = work.x[ir, np.newaxis] + hr\n x_eval[ic] = work.x[ic, np.newaxis] + hc\n x_eval[il] = work.x[il, np.newaxis] - hr\n return x_eval\n\n def post_func_eval(x, f, work):\n \"\"\" Estimate the derivative and error from the function 
evaluations\n\n As in `pre_func_eval`: in the first iteration, there is only one stored\n function value in `work.fs`, `f(x)`, so we need to add the `order` new\n points. In subsequent iterations, we add two new points. The tricky\n part is getting the order to match that of the weights, which is\n described in `_differentiate_weights`.\n\n For improvement:\n - Change the order of the weights (and steps in `pre_func_eval`) to\n simplify `work_fc` concatenation and eliminate `fc` concatenation.\n - It would be simple to do one-step Richardson extrapolation with `df`\n and `df_last` to increase the order of the estimate and/or improve\n the error estimate.\n - Process the function evaluations in a more numerically favorable\n way. For instance, combining the pairs of central difference evals\n into a second-order approximation and using Richardson extrapolation\n to produce a higher order approximation seemed to retain accuracy up\n to very high order.\n - Alternatively, we could use `polyfit` like Jacobi. 
An advantage of\n fitting polynomial to more points than necessary is improved noise\n tolerance.\n \"\"\"\n n = work.terms\n n_new = n if work.nit == 0 else 1\n il, ic, io = work.il, work.ic, work.io\n\n # Central difference\n # `work_fc` is *all* the points at which the function has been evaluated\n # `fc` is the points we're using *this iteration* to produce the estimate\n work_fc = (f[ic, :n_new], work.fs[ic, :], f[ic, -n_new:])\n work_fc = np.concatenate(work_fc, axis=-1)\n if work.nit == 0:\n fc = work_fc\n else:\n fc = (work_fc[:, :n], work_fc[:, n:n+1], work_fc[:, -n:])\n fc = np.concatenate(fc, axis=-1)\n\n # One-sided difference\n work_fo = np.concatenate((work.fs[io, :], f[io, :]), axis=-1)\n if work.nit == 0:\n fo = work_fo\n else:\n fo = np.concatenate((work_fo[:, 0:1], work_fo[:, -2*n:]), axis=-1)\n\n work.fs = np.zeros((len(ic), work.fs.shape[-1] + 2*n_new))\n work.fs[ic] = work_fc\n work.fs[io] = work_fo\n\n wc, wo = _differentiate_weights(work, n)\n work.df_last = work.df.copy()\n work.df[ic] = fc @ wc / work.h\n work.df[io] = fo @ wo / work.h\n work.df[il] *= -1\n\n work.h /= work.fac\n work.error_last = work.error\n # Simple error estimate - the difference in derivative estimates between\n # this iteration and the last. This is typically conservative because if\n # convergence has begin, the true error is much closer to the difference\n # between the current estimate and the *next* error estimate. 
However,\n # we could use Richarson extrapolation to produce an error estimate that\n # is one order higher, and take the difference between that and\n # `work.df` (which would just be constant factor that depends on `fac`.)\n work.error = abs(work.df - work.df_last)\n\n def check_termination(work):\n \"\"\"Terminate due to convergence, non-finite values, or error increase\"\"\"\n stop = np.zeros_like(work.df).astype(bool)\n\n i = work.error < work.atol + work.rtol*abs(work.df)\n work.status[i] = eim._ECONVERGED\n stop[i] = True\n\n if work.nit > 0:\n i = ~((np.isfinite(work.x) & np.isfinite(work.df)) | stop)\n work.df[i], work.status[i] = np.nan, eim._EVALUEERR\n stop[i] = True\n\n # With infinite precision, there is a step size below which\n # all smaller step sizes will reduce the error. But in floating point\n # arithmetic, catastrophic cancellation will begin to cause the error\n # to increase again. This heuristic tries to avoid step sizes that are\n # too small. There may be more theoretically sound approaches for\n # detecting a step size that minimizes the total error, but this\n # heuristic seems simple and effective.\n i = (work.error > work.error_last*10) & ~stop\n work.status[i] = _EERRORINCREASE\n stop[i] = True\n\n return stop\n\n def post_termination_check(work):\n return\n\n def customize_result(res, shape):\n return shape\n\n return eim._loop(work, callback, shape, maxiter, func, args, dtype,\n pre_func_eval, post_func_eval, check_termination,\n post_termination_check, customize_result, res_work_pairs,\n preserve_shape)\n```\n###test function signature:\n```python\n\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_maxiter_callback", + "reference": "\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n x = 0.612814\n dist = stats.norm()\n maxiter = 3\n\n def f(x):\n res 
= dist.cdf(x)\n return res\n\n default_order = 8\n res = differentiate(f, x, maxiter=maxiter, rtol=1e-15)\n assert not np.any(res.success)\n assert np.all(res.nfev == default_order + 1 + (maxiter - 1)*2)\n assert np.all(res.nit == maxiter)\n\n def callback(res):\n callback.iter += 1\n callback.res = res\n assert hasattr(res, 'x')\n assert res.df not in callback.dfs\n callback.dfs.add(res.df)\n assert res.status == eim._EINPROGRESS\n if callback.iter == maxiter:\n raise StopIteration\n callback.iter = -1 # callback called once before first iteration\n callback.res = None\n callback.dfs = set()\n\n res2 = differentiate(f, x, callback=callback, rtol=1e-15)\n # terminating with callback is identical to terminating due to maxiter\n # (except for `status`)\n for key in res.keys():\n if key == 'status':\n assert res[key] == eim._ECONVERR\n assert callback.res[key] == eim._EINPROGRESS\n assert res2[key] == eim._ECALLBACK\n else:\n assert res2[key] == callback.res[key] == res[key]", + "completions": [ + "```python\n def test_maxiter_callback(self):\n # Test behavior of `maxiter` parameter and `callback` interface\n def callback(work):\n if work.nit == 0:\n work.maxiter = 1\n else:\n work.maxiter = 0\n res = _differentiate(lambda x: x**2, 0, maxiter=2, callback=callback)\n self.assertEqual(res.nit, 1)\n self.assertEqual(res.status, 0)\n```\n" + ], + "line": 187, + "token": 1981, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 22, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef get_standard_metadata(\n pyproject_dict: Mapping[str, Any],\n settings: ScikitBuildSettings,\n) -> StandardMetadata:\n new_pyproject_dict = dict(pyproject_dict)\n # Handle any dynamic metadata\n for field, provider, config in load_dynamic_metadata(settings.metadata):\n if provider is None:\n msg = f\"{field} is missing provider\"\n raise KeyError(msg)\n if field not in pyproject_dict.get(\"project\", {}).get(\"dynamic\", []):\n msg = f\"{field} is not in project.dynamic\"\n raise KeyError(msg)\n new_pyproject_dict[\"project\"][field] = provider.dynamic_metadata(field, config)\n new_pyproject_dict[\"project\"][\"dynamic\"].remove(field)\n\n metadata = StandardMetadata.from_pyproject(new_pyproject_dict)\n\n # For scikit-build-core < 0.5, we keep the normalized name for back-compat\n if settings.minimum_version is not None and settings.minimum_version < Version(\n \"0.5\"\n ):\n metadata.name = metadata.canonical_name\n\n # The description field is required to be one line. 
Instead of merging it\n # or cutting off subsequent lines (setuptools), we throw a nice error.\n # But we didn't validate before 0.9.\n if (\n settings.minimum_version is None or settings.minimum_version >= Version(\"0.9\")\n ) and \"\\n\" in (metadata.description or \"\"):\n msg = \"Multiple lines in project.description are not supported; this is supposed to be a one line summary\"\n raise ValueError(msg)\n\n return metadata\n```\n###test function signature:\n```python\n\ndef test_plugin_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_plugin_metadata", + "reference": "\ndef test_plugin_metadata():\n reason_msg = (\n \"install hatch-fancy-pypi-readme and setuptools-scm to test the \"\n \"dynamic metadata plugins\"\n )\n pytest.importorskip(\"hatch_fancy_pypi_readme\", reason=reason_msg)\n pytest.importorskip(\"setuptools_scm\", reason=reason_msg)\n if shutil.which(\"git\") is None:\n pytest.skip(\"git is not installed\")\n\n shutil.copy(\"plugin_project.toml\", \"pyproject.toml\")\n\n subprocess.run([\"git\", \"init\", \"--initial-branch=main\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.name\", \"bot\"], check=True)\n subprocess.run([\"git\", \"config\", \"user.email\", \"bot@scikit-build.org\"], check=True)\n subprocess.run([\"git\", \"add\", \"pyproject.toml\"], check=True)\n subprocess.run([\"git\", \"commit\", \"-m\", \"initial commit\"], check=True)\n subprocess.run([\"git\", \"tag\", \"v0.1.0\", \"-m\", \"initial commint\"], check=True)\n\n with Path(\"pyproject.toml\").open(\"rb\") as ft:\n pyproject = tomllib.load(ft)\n settings_reader = SettingsReader(pyproject, {}, state=\"metadata_wheel\")\n settings = settings_reader.settings\n\n settings_reader.validate_may_exit()\n\n metadata = get_standard_metadata(pyproject, settings)\n\n assert str(metadata.version) == \"0.1.0\"\n assert metadata.readme == pyproject_metadata.Readme(\n \"Fragment #1Fragment #2\", None, \"text/x-rst\"\n )\n\n assert 
set(GetRequires().dynamic_metadata()) == {\n \"hatch-fancy-pypi-readme>=22.3\",\n \"setuptools-scm\",\n }", + "completions": [ + "```python\n\ndef test_plugin_metadata():\n from skbuild import __version__ as skbuild_version\n from skbuild.settings import ScikitBuildSettings\n from skbuild.metadata import StandardMetadata\n from skbuild.utils import load_dynamic_metadata\n from skbuild.version import Version\n\n pyproject_dict = {\n \"build-system\": {\n \"requires\": [\"setuptools>=40.8.0\", \"wheel\"],\n \"build-backend\": \"setuptools.build_meta\",\n },\n \"project\": {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n \"description\": \"A build system for Python packages\",\n \"dynamic\": [\"name\", \"version\"],\n },\n }\n\n settings = ScikitBuildSettings(\n minimum_version=Version(\"0.1.0\"),\n metadata=[\n (\"name\", None, None),\n (\"version\", None, None),\n ],\n )\n\n metadata = get_standard_metadata(pyproject_dict, settings)\n\n assert metadata.name == \"scikit-build\"\n assert metadata.version == \"0.1.0\"\n assert metadata.description == \"A build system for Python packages\"\n\n assert metadata.canonical_name == \"scikit-build\"\n assert metadata.canonical_version == \"0.1.0\"\n\n assert metadata.dynamic_metadata == {\n \"name\": \"scikit-build\",\n \"version\": \"0.1.0\",\n }\n\n assert metadata.dynamic_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_providers == [None, None]\n\n assert metadata.dynamic_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields == [\"name\", \"version\"]\n\n assert metadata.dynamic_metadata_providers == [None, None]\n\n assert metadata.dynamic_metadata_configs == [None, None]\n\n assert metadata.dynamic_metadata_fields_by_provider == {None: [\"name\", \"version\"]}\n\n assert metadata.dynamic_metadata_providers_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert metadata.dynamic_metadata_configs_by_field == {\n \"name\": None,\n \"version\": None,\n }\n\n assert 
metadata.dynamic_metadata_fields_by_provider_and_field == {\n (None, \"name\"): \"scikit-build\",\n (None, \"version\"): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider == {\n (\"name\", None): \"scikit-build\",\n (\"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider == {\n (\"name\", None): None,\n (\"version\", None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config == {\n (None, \"name\", None): \"scikit-build\",\n (None, \"version\", None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config == {\n (\"name\", None, None): \"scikit-build\",\n (\"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_configs_by_field_and_provider_and_config == {\n (\"name\", None, None): None,\n (\"version\", None, None): None,\n }\n\n assert metadata.dynamic_metadata_fields_by_provider_and_field_and_config_and_provider == {\n (None, \"name\", None, None): \"scikit-build\",\n (None, \"version\", None, None): \"0.1.0\",\n }\n\n assert metadata.dynamic_metadata_providers_by_field_and_provider_and_config_and_provider == {\n (\"name\", None, None" + ], + "line": 40, + "token": 347, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 23, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _get_datetime_format(value):\n if not isinstance(value, pd.Series):\n value = pd.Series(value)\n\n value = value[~value.isna()]\n value = value.astype(str).to_numpy()\n\n return _guess_datetime_format_for_array(value)\n```\n###test function signature:\n```python\ndef test__get_datetime_format():\n # Setup\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test__get_datetime_format", + "reference": "def test__get_datetime_format():\n # Setup\n string_value = '2021-02-02'\n list_value = [np.nan, '2021-02-02']\n series_value = pd.Series(['2021-02-02T12:10:59'])\n\n # Run\n string_out = _get_datetime_format(string_value)\n list_out = _get_datetime_format(list_value)\n series_out = _get_datetime_format(series_value)\n\n # Assert\n expected_output = '%Y-%m-%d'\n assert string_out == expected_output\n assert list_out == expected_output\n assert series_out == '%Y-%m-%dT%H:%M:%S'", + "completions": [ + "```python\ndef test__get_datetime_format():\n # Setup\n value = pd.Series([\"2020-01-01\", \"2020-01-02\", \"2020-01-03\"])\n expected = \"%Y-%m-%d\"\n\n # Exercise\n actual = _get_datetime_format(value)\n\n # Verify\n assert actual == expected\n```\n" + ], + "line": 17, + "token": 124, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 24, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def load(self) -> None:\n super().load()\n\n _path = str(\n importlib_resources.files(\"distilabel\")\n / \"steps\"\n / \"tasks\"\n / \"templates\"\n / \"ultrafeedback\"\n / f\"{self.aspect}.jinja2\"\n )\n\n self._template = Template(open(_path).read())\n```\n###test function signature:\n```python\n\n def test_process_with_complex_aspect(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_process_with_complex_aspect", + "reference": "\n def test_process_with_complex_aspect(self) -> None:\n pipeline = Pipeline(name=\"unit-test-pipeline\")\n llm = UltraFeedbackLLM()\n\n task = UltraFeedback(\n name=\"ultrafeedback\",\n aspect=\"truthfulness\",\n llm=llm,\n pipeline=pipeline,\n )\n task.load()\n\n assert next(\n task.process([{\"instruction\": \"test\", \"generations\": [\"A\", \"B\"]}])\n ) == [\n {\n \"instruction\": \"test\",\n \"generations\": [\"A\", \"B\"],\n \"types\": [1, 2],\n \"rationales\": [\"text\", \"text\"],\n \"ratings\": [1, 2],\n \"rationales-for-ratings\": [\"text\", \"text\"],\n \"model_name\": \"ultrafeedback-model\",\n }\n ]", + "completions": [ + "```python\n def test_process_with_complex_aspect(self) -> None:\n self.aspect = \"complex_aspect\"\n self.load()\n self.assertEqual(\n self._template.render(\n {\n \"aspect\": self.aspect,\n \"aspect_name\": \"Complex Aspect\",\n \"aspect_description\": \"This is a complex aspect\",\n \"aspect_type\": \"complex\",\n \"aspect_value\": \"1\",\n \"aspect_weight\": \"1\",\n \"aspect_color\": \"red\",\n \"aspect_icon\": \"icon-circle\",\n \"aspect_icon_color\": \"red\",\n \"aspect_icon_bg_color\": \"white\",\n \"aspect_icon_size\": \"20\",\n \"aspect_icon_border_color\": \"red\",\n \"aspect_icon_border_width\": \"2\",\n \"aspect_icon_border_radius\": \"50\",\n 
\"aspect_icon_padding\": \"5\",\n \"aspect_icon_margin\": \"5\",\n \"aspect_icon_bg_size\": \"20\",\n \"aspect_icon_bg_border_color\": \"red\",\n \"aspect_icon_bg_border_width\": \"2\",\n \"aspect_icon_bg_border_radius\": \"50\",\n \"aspect_icon_bg_padding\": \"5\",\n \"aspect_icon_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_margin\": \"5\",\n \"aspect_icon_bg_bg_bg_bg_bg_color\": \"white\",\n \"aspect_icon_bg_bg_bg_bg_bg_size\": \"20\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_color\": \"red\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_width\": \"2\",\n \"aspect_icon_bg_bg_bg_bg_bg_border_radius\": \"50\",\n \"aspect_icon_bg_bg_bg_bg_bg_padding\": \"5\",\n \"aspect_icon_bg_bg_bg_" + ], + "line": 21, + "token": 145, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 25, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _setup_cuda_device(device: torch.device) -> torch.device:\n local_rank = _get_local_rank() or 0\n if device.index is None:\n device = torch.device(type=\"cuda\", index=local_rank)\n\n # Ensure index is available before setting device\n if device.index >= torch.cuda.device_count():\n raise RuntimeError(\n \"The local rank is larger than the number of available GPUs.\"\n )\n\n torch.cuda.set_device(device)\n return device\n```\n###test function signature:\n```python\n\n def test_get_gpu_device(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_gpu_device", + "reference": "\n def test_get_gpu_device(self) -> None:\n device_idx = torch.cuda.device_count() - 1\n assert device_idx >= 0\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(device_idx)}, clear=True):\n device = get_device()\n assert device.type == \"cuda\"\n assert device.index == device_idx\n assert device.index == torch.cuda.current_device()\n\n # Test that we raise an error if the device index is specified on distributed runs\n if device_idx > 0:\n with pytest.raises(\n RuntimeError,\n match=f\"Device specified is cuda:0 but was assigned cuda:{device_idx}\",\n ):\n device = get_device(\"cuda:0\")\n\n invalid_device_idx = device_idx + 10\n with mock.patch.dict(os.environ, {\"LOCAL_RANK\": str(invalid_device_idx)}):\n with pytest.raises(\n RuntimeError,\n match=\"The local rank is larger than the number of available GPUs\",\n ):\n device = get_device(\"cuda\")\n\n # Test that we fall back to 0 if LOCAL_RANK is not specified\n device = torch.device(_get_device_type_from_env())\n device = _setup_cuda_device(device)\n assert device.type == \"cuda\"\n assert device.index == 0\n assert device.index == torch.cuda.current_device()", + "completions": [ + 
"```python\n def test_get_gpu_device(self) -> None:\n device = torch.device(type=\"cuda\", index=0)\n self.assertEqual(device, _setup_cuda_device(device))\n```\n" + ], + "line": 21, + "token": 157, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 26, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def download(cls, client: GiskardClient, project_key, model_id, *_args, **_kwargs):\n local_dir = settings.home_dir / settings.cache_dir / \"models\" / model_id\n client.load_artifact(local_dir, posixpath.join(\"models\", model_id))\n meta_response: ModelMetaInfo = client.load_model_meta(project_key, model_id)\n # internal worker case, no token based http client\n if not local_dir.exists():\n raise RuntimeError(f\"Cannot find existing model {project_key}.{model_id} in {local_dir}\")\n with (Path(local_dir) / META_FILENAME).open(encoding=\"utf-8\") as f:\n file_meta = yaml.load(f, Loader=yaml.Loader)\n classification_labels = cls.cast_labels(meta_response)\n meta = ModelMeta(\n name=meta_response.name,\n description=meta_response.description,\n model_type=SupportedModelTypes[meta_response.modelType],\n feature_names=meta_response.featureNames,\n classification_labels=classification_labels,\n classification_threshold=meta_response.threshold,\n loader_module=file_meta[\"loader_module\"],\n loader_class=file_meta[\"loader_class\"],\n )\n\n model_py_ver = (\n tuple(meta_response.languageVersion.split(\".\")) if \"PYTHON\" == meta_response.language.upper() else None\n )\n\n clazz = cls.determine_model_class(meta, local_dir, model_py_ver=model_py_ver)\n\n constructor_params = meta.__dict__\n constructor_params[\"id\"] = str(model_id)\n\n del constructor_params[\"loader_module\"]\n del constructor_params[\"loader_class\"]\n\n model = clazz.load(local_dir, 
model_py_ver=model_py_ver, **constructor_params)\n return model\n```\n###test function signature:\n```python\n\ndef test_model_download(request):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_model_download", + "reference": "\ndef test_model_download(request):\n model: BaseModel = request.getfixturevalue(\"german_credit_model\")\n project_key = str(uuid.uuid4())\n\n with utils.MockedProjectCacheDir():\n with utils.MockedClient(mock_all=False) as (client, mr):\n # The model needs to request files\n requested_urls = []\n requested_urls.extend(utils.register_uri_for_model_meta_info(mr, model, project_key))\n requested_urls.extend(\n utils.register_uri_for_model_artifact_info(mr, model, project_key, register_file_contents=True)\n )\n\n downloaded_model = BaseModel.download(client=client, project_key=project_key, model_id=str(model.id))\n\n for requested_url in requested_urls:\n assert utils.is_url_requested(mr.request_history, requested_url)\n\n assert downloaded_model.id == model.id\n assert downloaded_model.meta == model.meta", + "completions": [ + "```python\ndef test_model_download(request):\n from pathlib import Path\n from unittest.mock import MagicMock\n from giskard.client import GiskardClient\n from giskard.models import ModelMetaInfo, ModelMeta, SupportedModelTypes\n from giskard.utils import cast_labels\n from giskard.models.base import Model\n from giskard.models.classification import ClassificationModel\n from giskard.models.regression import RegressionModel\n from giskard.models.clustering import ClusteringModel\n from giskard.models.dim_reduction import DimReductionModel\n from giskard.models.anomaly_detection import AnomalyDetectionModel\n from giskard.models.time_series import TimeSeriesModel\n from giskard.models.text_classification import TextClassificationModel\n from giskard.models.text_recommendation import TextRecommendationModel\n from giskard.models.image_classification import ImageClassificationModel\n from 
giskard.models.image_recommendation import ImageRecommendationModel\n from giskard.models.image_segmentation import ImageSegmentationModel\n from giskard.models.image_generation import ImageGenerationModel\n from giskard.models.image_style_transfer import ImageStyleTransferModel\n from giskard.models.image_super_resolution import ImageSuperResolutionModel\n from giskard.models.image_denoising import ImageDenoisingModel\n from giskard.models.image_colorization import ImageColorizationModel\n from giskard.models.image_deraining import ImageDerainingModel\n from giskard.models.image_deblurring import ImageDeblurringModel\n from giskard.models.image_edge_detection import ImageEdgeDetectionModel\n from giskard.models.image_face_detection import ImageFaceDetectionModel\n from giskard.models.image_face_recognition import ImageFaceRecognitionModel\n from giskard.models.image_object_detection import ImageObjectDetectionModel\n from giskard.models.image_object_recognition import ImageObjectRecognitionModel\n from giskard.models.image_segmentation_mask_rcnn import ImageSegmentationMaskRCNNModel\n from giskard.models.image_segmentation_yolo import ImageSegmentationYOLOModel\n from giskard.models.image_segmentation_faster_rcnn import ImageSegmentationFasterRCNNModel\n from giskard.models.image_segmentation_unet import ImageSegmentationUNetModel\n from giskard.models.image_segmentation_deeplab import ImageSegmentationDeepLabModel\n from giskard.models.image_segmentation_fcn import ImageSegmentationFCNModel\n from giskard.models.image_segmentation_pspnet import ImageSegmentationPSPNetModel\n from giskard.models.image_segmentation_linknet import ImageSegmentationLinkNetModel\n from giskard.models.image_segmentation_unet_plus import ImageSegmentationUNetPlusModel\n from giskard.models.image_segmentation_deeplab_v3 import ImageSegmentationDeepLabV3Model\n from giskard.models.image_segmentation_deeplab_v3_plus import ImageSegmentationDeepLabV3PlusModel\n from 
giskard.models.image_segmentation_fcn_resnet101 import ImageSegmentationFCNResNet101Model\n from giskard.models.image_segmentation_fcn_resnet50 import ImageSegmentationFCNResNet50Model\n from giskard.models.image_segmentation_fcn_mobilenet_v2 import ImageSegmentationFCNMobileNetV2Model\n from giskard.models.image_segmentation_fcn_mobilenet_v3 import ImageSegmentationFCNMobileNetV3Model\n from giskard.models.image_segmentation_fcn_efficientnet_b" + ], + "line": 40, + "token": 316, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 27, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in 
STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\ndef test_rounded_date_columns():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rounded_date_columns", + "reference": "def test_rounded_date_columns():\n df = pl.DataFrame(\n [\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.000000\", \"exp\": True},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00.123456\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:01.000000\", \"exp\": False},\n {\"a\": \"2017-01-01\", \"b\": \"2017-01-01 00:00:00\", \"exp\": True},\n ]\n )\n\n df = df.with_columns(df[\"a\"].str.to_date().alias(\"a_dt\"))\n actual = columns_equal(df[\"a_dt\"], df[\"b\"])\n expected = df[\"exp\"]\n assert_series_equal(actual, expected, check_names=False)", + "completions": [ + "```python\ndef test_rounded_date_columns():\n # given\n col_1 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n col_2 = pl.Series([\"2021-01-01\", \"2021-01-02\", \"2021-01-03\"])\n\n # when\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n\n # then\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 28, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef annotations_to_instances(annos, image_size, mask_format=\"polygon\"):\n boxes = [BoxMode.convert(obj[\"bbox\"], obj[\"bbox_mode\"], BoxMode.XYXY_ABS) for obj in annos]\n target = Instances(image_size)\n boxes = target.gt_boxes = Boxes(boxes)\n boxes.clip(image_size)\n\n classes = [obj[\"category_id\"] for obj in annos]\n classes = torch.tensor(classes, dtype=torch.int64)\n target.gt_classes = classes\n\n if len(annos) and \"segmentation\" in annos[0]:\n segms = [obj[\"segmentation\"] for obj in annos]\n if mask_format == \"polygon\":\n masks = PolygonMasks(segms)\n else:\n assert mask_format == \"bitmask\", mask_format\n masks = []\n for segm in segms:\n if isinstance(segm, list):\n # polygon\n masks.append(polygons_to_bitmask(segm, *image_size))\n elif isinstance(segm, dict):\n # COCO RLE\n masks.append(mask_util.decode(segm))\n elif isinstance(segm, np.ndarray):\n assert segm.ndim == 2, \"Expect segmentation of 2 dimensions, got {}.\".format(\n segm.ndim\n )\n # mask array\n masks.append(segm)\n else:\n raise ValueError(\n \"Cannot convert segmentation of type '{}' to BitMasks!\"\n \"Supported types are: polygons as list[list[float] or ndarray],\"\n \" COCO-style RLE as a dict, or a full-image segmentation mask \"\n \"as a 2D ndarray.\".format(type(segm))\n )\n # torch.from_numpy does not support array with negative stride.\n masks = BitMasks(\n torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks])\n )\n target.gt_masks = masks\n\n if len(annos) and \"keypoints\" in annos[0]:\n kpts = [obj.get(\"keypoints\", []) for obj in annos]\n target.gt_keypoints = Keypoints(kpts)\n\n return target\n```\n###test function signature:\n```python\n\n def test_transform_simple_annotation(self):\n```\n### Answer: (use the provided format with 
backticks)\n", + "name": "test_transform_simple_annotation", + "reference": "\n def test_transform_simple_annotation(self):\n transforms = T.TransformList([T.HFlipTransform(400)])\n anno = {\n \"bbox\": np.asarray([10, 10, 200, 300]),\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 3,\n \"segmentation\": [[10, 10, 100, 100, 100, 10], [150, 150, 200, 150, 200, 200]],\n }\n\n output = detection_utils.transform_instance_annotations(anno, transforms, (400, 400))\n self.assertTrue(np.allclose(output[\"bbox\"], [200, 10, 390, 300]))\n self.assertEqual(len(output[\"segmentation\"]), len(anno[\"segmentation\"]))\n self.assertTrue(np.allclose(output[\"segmentation\"][0], [390, 10, 300, 100, 300, 10]))\n\n detection_utils.annotations_to_instances([output, output], (400, 400))", + "completions": [ + "```python\n def test_transform_simple_annotation(self):\n annos = [\n {\n \"bbox\": [10, 10, 100, 100],\n \"bbox_mode\": BoxMode.XYXY_ABS,\n \"category_id\": 0,\n \"segmentation\": [\n [10, 10, 50, 50, 90, 90, 100, 100],\n [20, 20, 50, 50, 80, 80, 100, 100],\n ],\n \"keypoints\": [10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n }\n ]\n image_size = (100, 100)\n target = annotations_to_instances(annos, image_size)\n assert target.gt_boxes.tensor.tolist() == [[10, 10, 100, 100]]\n assert target.gt_classes.tolist() == [0]\n assert target.gt_masks.polygons().tolist() == [\n [[10, 10], [50, 50], [90, 90], [100, 100]],\n [[20, 20], [50, 50], [80, 80], [100, 100]],\n ]\n assert target.gt_keypoints.tensor.tolist() == [[10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]\n```\n" + ], + "line": 54, + "token": 437, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 29, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef assert_connectivity(connectivity: nx.Graph, circuit: Circuit):\n if list(connectivity.nodes) != list(range(connectivity.number_of_nodes())):\n node_mapping = {node: i for i, node in enumerate(connectivity.nodes)}\n new_connectivity = nx.Graph()\n new_connectivity.add_edges_from(\n [(node_mapping[u], node_mapping[v]) for u, v in connectivity.edges]\n )\n connectivity = new_connectivity\n for gate in circuit.queue:\n if len(gate.qubits) > 2 and not isinstance(gate, gates.M):\n raise_error(ConnectivityError, f\"{gate.name} acts on more than two qubits.\")\n if len(gate.qubits) == 2:\n if (gate.qubits[0], gate.qubits[1]) not in connectivity.edges:\n raise_error(\n ConnectivityError,\n f\"Circuit does not respect connectivity. {gate.name} acts on {gate.qubits}.\",\n )\n```\n###test function signature:\n```python\n\ndef test_sabre_simple(seed):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_sabre_simple", + "reference": "\ndef test_sabre_simple(seed):\n placer = Trivial()\n circ = Circuit(5)\n circ.add(gates.CZ(0, 1))\n initial_layout = placer(circ)\n router = Sabre(connectivity=star_connectivity(), seed=seed)\n routed_circuit, final_map = router(circuit=circ, initial_layout=initial_layout)\n assert router.added_swaps == 1\n assert final_map == {\"q0\": 2, \"q1\": 1, \"q2\": 0, \"q3\": 3, \"q4\": 4}\n assert routed_circuit.queue[0].qubits == (0, 2)\n assert isinstance(routed_circuit.queue[0], gates.SWAP)\n assert isinstance(routed_circuit.queue[1], gates.CZ)\n assert_connectivity(circuit=routed_circuit, connectivity=star_connectivity())\n assert_circuit_equivalence(\n original_circuit=circ,\n transpiled_circuit=routed_circuit,\n final_map=final_map,\n initial_map=initial_layout,\n )", + "completions": [ + "```python\ndef 
test_sabre_simple(seed):\n import random\n random.seed(seed)\n import networkx as nx\n from pyquil.quil import Program\n from pyquil.api import QVMConnection\n from pyquil.gates import CNOT, H, MEASURE\n from pyquil.quilbase import Declare\n from pyquil.quilbase import DefGate\n from pyquil.quilbase import DefCircuit\n from pyquil.quilbase import Gate\n from pyquil.quilbase import Circuit\n from pyquil.quilbase import QubitPlaceholder\n from pyquil.quilbase import MemoryReference\n from pyquil.quilbase import Parameter\n from pyquil.quilbase import Pragma\n from pyquil.quilbase import PragmaDeclare\n from pyquil.quilbase import PragmaDefGate\n from pyquil.quilbase import PragmaDefCircuit\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefQubit\n from pyquil.quilbase import PragmaDefMemory\n from pyquil.quilbase import PragmaDefRegister\n from pyquil.quilbase import PragmaDefLabel\n from pyquil.quilbase import PragmaDefFunction\n from pyquil.quilbase import PragmaDefProcedure\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import 
PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyquil.quilbase import PragmaDefProgram\n from pyqu" + ], + "line": 27, + "token": 242, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 30, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef unique_patterns_2d(\n agrid: NDArray[np.int64], ksize: int, periodic_input: bool\n) -> tuple[NDArray[np.int64], NDArray[np.int64], NDArray[np.int64]]:\n assert ksize >= 1\n if periodic_input:\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n else:\n # TODO: implement non-wrapped image handling\n # a = np.pad(a, ((0,k-1),(0,k-1),*(((0,0),)*(len(a.shape)-2))), mode='constant', constant_values=None)\n agrid = np.pad(\n agrid,\n ((0, ksize - 1), (0, ksize - 1), *(((0, 0),) * (len(agrid.shape) - 2))),\n mode=\"wrap\",\n )\n\n patches: NDArray[np.int64] = np.lib.stride_tricks.as_strided(\n agrid,\n (\n agrid.shape[0] - ksize + 1,\n agrid.shape[1] - ksize + 1,\n ksize,\n ksize,\n *agrid.shape[2:],\n ),\n agrid.strides[:2] + agrid.strides[:2] + agrid.strides[2:],\n writeable=False,\n )\n patch_codes = hash_downto(patches, 2)\n uc, ui = np.unique(patch_codes, return_index=True)\n locs = np.unravel_index(ui, patch_codes.shape)\n up: NDArray[np.int64] = patches[locs[0], locs[1]]\n ids: NDArray[np.int64] = np.vectorize(\n {code: ind for ind, code in enumerate(uc)}.get\n )(patch_codes)\n return ids, up, patch_codes\n```\n###test function signature:\n```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_patterns_2d", + "reference": "\ndef test_unique_patterns_2d(img_redmaze) -> None:\n img = img_redmaze\n tile_size = 1\n pattern_width = 2\n _tile_catalog, tile_grid, _code_list, _unique_tiles = wfc_tiles.make_tile_catalog(\n img, tile_size\n )\n\n (\n _patterns_in_grid,\n pattern_contents_list,\n patch_codes,\n ) = wfc_patterns.unique_patterns_2d(tile_grid, pattern_width, True)\n assert 
patch_codes[1][2] == 4867810695119132864\n assert pattern_contents_list[7][1][1] == 8253868773529191888", + "completions": [ + "```python\n\ndef test_unique_patterns_2d(img_redmaze) -> None:\n ids, up, patch_codes = unique_patterns_2d(img_redmaze, 3, True)\n assert ids.shape == (10,)\n assert up.shape == (10, 3, 3)\n assert patch_codes.shape == (10,)\n```\n" + ], + "line": 48, + "token": 425, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 31, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef serialize(data: Any, encoding: str | None = None) -> str | bytes:\n serialized: bytes | str\n\n if encoding is not None:\n serialized = yaml.dump(data, encoding=encoding, Dumper=AiiDADumper)\n else:\n serialized = yaml.dump(data, Dumper=AiiDADumper)\n\n return serialized\n```\n###test function signature:\n```python\ndef test_mixed_attribute_normal_dict():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mixed_attribute_normal_dict", + "reference": "def test_mixed_attribute_normal_dict():\n from aiida.common.extendeddicts import AttributeDict\n\n # Construct a nested `AttributeDict`, which should make all nested dictionaries `AttributeDicts` recursively\n dictionary = {'nested': AttributeDict({'dict': 'string', 'value': 1})}\n attribute_dict = AttributeDict(dictionary)\n\n # Now add a normal dictionary in the attribute dictionary\n attribute_dict['nested']['normal'] = {'a': 2}\n\n serialized = serialize.serialize(attribute_dict)\n deserialized = serialize.deserialize_unsafe(serialized)\n\n assert attribute_dict, deserialized", + "completions": [ + "```python\ndef test_mixed_attribute_normal_dict():\n data = {\n \"a\": 1,\n \"b\": 2,\n \"c\": 3,\n \"d\": 4,\n \"e\": 5,\n \"f\": 6,\n \"g\": 7,\n \"h\": 8,\n \"i\": 9,\n \"j\": 10,\n 
\"k\": 11,\n \"l\": 12,\n \"m\": 13,\n \"n\": 14,\n \"o\": 15,\n \"p\": 16,\n \"q\": 17,\n \"r\": 18,\n \"s\": 19,\n \"t\": 20,\n \"u\": 21,\n \"v\": 22,\n \"w\": 23,\n \"x\": 24,\n \"y\": 25,\n \"z\": 26,\n }\n assert serialize(data) == \"\"\"a: 1\nb: 2\nc: 3\nd: 4\ne: 5\nf: 6\ng: 7\nh: 8\ni: 9\nj: 10\nk: 11\nl: 12\nm: 13\nn: 14\no: 15\np: 16\nq: 17\nr: 18\ns: 19\nt: 20\nu: 21\nv: 22\nw: 23\nx: 24\ny: 25\nz: 26\n\"\"\"\n```\n" + ], + "line": 17, + "token": 133, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 32, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected_against_alias", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": 
\"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n http_client.select_model(model_id=\"yolov8n-640\")\n\n # when\n result = http_client.unload_model(model_id=\"yolov8n-640\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"coco/3\",\n }\n assert (\n http_client.selected_model is None\n ), \"Even when alias is in use - selected model should be emptied\"", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected_against_alias(\n requests_mock: Mocker,\n) -> None:\n # given\n requests_mock.post(\n f\"{API_URL}/model/remove\",\n json={\n \"model_id\": \"test_model_id\",\n },\n headers=DEFAULT_HEADERS,\n )\n client = RoboflowClient(api_key=\"test_api_key\", api_secret=\"test_api_secret\")\n # when\n response = client.unload_model(model_id=\"test_model_alias\")\n # then\n assert response.model_id == \"test_model_id\"\n assert response.model_name == \"test_model_name\"\n assert response.model_version == \"test_model_version\"\n assert response.model_type == \"test_model_type\"\n assert response.model_status == \"test_model_status\"\n assert response.model_created_at == \"test_model_created_at\"\n assert response.model_updated_at == \"test_model_updated_at\"\n assert response.model_tags == \"test_model_tags\"\n assert response.model_description == \"test_model_description\"\n assert response.model_url == \"test_model_url\"\n assert response.model_thumbnail_url == \"test_model_thumbnail_url\"\n assert response.model_training_data_url == \"test_model_training_data_url\"\n assert response.model_training_data_type == \"test_model_training_data_type\"\n assert response.model_training_data_size == \"test_model_training_data_size\"\n assert response.model_training_data_description == \"test_model_training_data_description\"\n assert 
response.model_training_data_tags == \"test_model_training_data_tags\"\n assert response.model_training_data_created_at == \"test_model_training_data_created_at\"\n assert response.model_training_data_updated_at == \"test_model_training_data_updated_at\"\n assert response.model_training_data_status == \"test_model_training_data_status\"\n assert response.model_training_data_status_message == \"test_model_training_data_status_message\"\n assert response.model_training_data_status_details == \"test_model_training_data_status_details\"\n assert response.model_training_data_status_created_at == \"test_model_training_data_status_created_at\"\n assert response.model_training_data_status_updated_at == \"test_model_training_data_status_updated_at\"\n assert response.model_training_data_status_completed_at == \"test_model_training_data_status_completed_at\"\n assert response.model_training_data_status_duration == \"test_model_training_data_status_duration\"\n assert response.model_training_data_status_budget_type == \"test_model_training_data_status_budget_type\"\n assert response.model_training_data_status_budget_amount == \"test_model_training_data_status_budget_amount\"\n assert response.model_training_data_status_budget_unit == \"test_model_training_data_status_budget_unit\"\n assert response.model_training_data_status_budget_remaining == \"test_model_training_data_status_budget_remaining\"\n assert response.model_training_data_status_budget_message == \"test_model_training_data_status_budget_message\"\n assert response.model_training_data_status_budget_details == \"test_model_training_data_status_budget_details\"\n assert response.model_training_data_status_budget_created_at == \"test_model_training_data_status_budget_created_at\"\n assert response.model_training_data_status_budget_updated_at == \"test_" + ], + "line": 31, + "token": 162, + "line_diff": 2, + "token_diff": 0 + }, + { + "id": 33, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef edge_index_to_vector(\n edge_index: Tensor,\n size: Tuple[int, int],\n bipartite: bool,\n force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n row, col = edge_index\n\n if bipartite: # No need to account for self-loops.\n idx = (row * size[1]).add_(col)\n population = size[0] * size[1]\n return idx, population\n\n elif force_undirected:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We only operate on the upper triangular matrix:\n mask = row < col\n row, col = row[mask], col[mask]\n offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n idx = row.mul_(num_nodes).add_(col).sub_(offset)\n population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n return idx, population\n\n else:\n assert size[0] == size[1]\n num_nodes = size[0]\n\n # We remove self-loops as we do not want to take them into account\n # when sampling negative values.\n mask = row != col\n row, col = row[mask], col[mask]\n col[row < col] -= 1\n idx = row.mul_(num_nodes - 1).add_(col)\n population = num_nodes * num_nodes - num_nodes\n return idx, population\n```\n###test function signature:\n```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_edge_index_to_vector_and_vice_versa", + "reference": "\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n N = 10\n row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n edge_index = torch.stack([row, col], dim=0)\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n assert population == N * N\n assert idx.tolist() == list(range(population))\n edge_index2 = 
vector_to_edge_index(idx, (N, N), bipartite=True)\n assert is_undirected(edge_index2)\n assert edge_index.tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n assert population == N * N - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n force_undirected=True)\n assert population == (N * (N + 1)) / 2 - N\n assert idx.tolist() == list(range(population))\n mask = edge_index[0] != edge_index[1] # Remove self-loops.\n edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n force_undirected=True)\n assert is_undirected(edge_index2)\n assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()", + "completions": [ + "```python\n\ndef test_edge_index_to_vector_and_vice_versa():\n # Create a fully-connected graph:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = False\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, 
force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n\n # Test force_undirected:\n edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n size = (4, 4)\n bipartite = False\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]\n assert population == 16\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]))\n\n # Test force_undirected with bipartite graph:\n edge_index = torch.tensor([[0, 1], [1, 0]])\n size = (2, 2)\n bipartite = True\n force_undirected = True\n idx, population = edge_index_to_vector(edge_index, size, bipartite, force_undirected)\n assert idx.tolist() == [0, 1, 2, 3]\n assert population == 4\n\n # Convert back to edge index:\n edge_index = vector_to_edge_index(idx, size, bipartite, force_undirected)\n assert torch.all(edge_index == torch.tensor([[0, 1], [1, 0]]))\n```\n" + ], + "line": 43, + "token": 351, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 34, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_integration_test(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_integration_test", + "reference": "\n def test_spectrogram_integration_test(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n # fmt: off\n expected = np.array([\n 0.02464888, 0.04648664, 0.05872392, 0.02311783, 0.0327175 ,\n 0.02433643, 0.01198814, 0.02055709, 0.01559287, 0.01394357,\n 0.01299037, 0.01728045, 0.0254554 , 0.02486533, 0.02011792,\n 0.01755333, 0.02100457, 0.02337024, 0.01436963, 0.01464558,\n 0.0211017 , 0.0193489 , 0.01272165, 0.01858462, 0.03722598,\n 0.0456542 , 0.03281558, 0.00620586, 0.02226466, 0.03618042,\n 0.03508182, 0.02271432, 0.01051649, 0.01225771, 0.02315293,\n 0.02331886, 0.01417785, 0.0106844 , 0.01791214, 0.017177 ,\n 0.02125114, 0.05028201, 0.06830665, 0.05216664, 0.01963666,\n 0.06941418, 0.11513043, 0.12257859, 0.10948435, 0.08568069,\n 0.05509328, 0.05047818, 0.047112 , 0.05060737, 0.02982424,\n 0.02803827, 0.02933729, 0.01760491, 0.00587815, 0.02117637,\n 0.0293578 , 0.03452379, 0.02194803, 0.01676056,\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n self.assertTrue(np.allclose(spec[:64, 400], expected))\n\n mel_filters = mel_filter_bank(\n 
num_frequency_bins=256,\n num_mel_filters=400,\n min_frequency=20,\n max_frequency=8000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"kaldi\",\n triangularize_in_mel_space=True,\n )\n\n mel_filters = np.pad(mel_filters, ((0, 1), (0, 0)))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"povey\", periodic=False),\n frame_length=400,\n hop_length=160,\n fft_length=512,\n power=2.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n preemphasis=0.97,\n mel_filters=mel_filters,\n log_mel=\"log\",\n mel_floor=1.1920928955078125e-07,\n remove_dc_offset=True,\n )\n self.assertEqual(spec.shape, (400, 584))\n\n # fmt: off\n expected = np.array([-15.94238515, -8.20712299, -8.22704352, -15.94238515,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -6.52463769, -7.73677889, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.18650018, -3.37195286,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -4.70190154, -2.4217066 , -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -5.62755239, -3.53385194,\n -15.94238515, -15.94238515, -15.94238515, -15.94238515,\n -9.43303023, -8.77480925, -15.94238515, -15.94238515,\n -15.94238515, -15.94238515, -4.2951092 , -5.51585994,\n -15.94238515, -15.94238515, -15.94238515, -4.40151721,\n -3.95228878, -15.94238515, -15.94238515, -15.94238515,\n -6.10365415, -4.59494697, -15.94238515, -15.94238515,\n -15.94238515, -8.10727767, -6.2585298 , -15.94238515,\n -15.94238515, -15.94238515, -5.60161702, -4.47217004,\n -15.94238515, -15.94238515, -15.94238515, -5.91641988]\n )\n # fmt: on\n self.assertTrue(np.allclose(spec[:64, 400], expected, atol=1e-5))", + "completions": [ + "```python\n def test_spectrogram_integration_test(self):\n waveform = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])\n window = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n 
onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n expected_spectrogram = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]])\n\n self.assertTrue(np.allclose(spectrogram, expected_spectrogram))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 35, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef run_query_expansion_node(modules: List[Callable],\n module_params: List[Dict],\n previous_result: pd.DataFrame,\n node_line_dir: str,\n strategies: Dict,\n ) -> pd.DataFrame:\n if not os.path.exists(node_line_dir):\n os.makedirs(node_line_dir)\n node_dir = os.path.join(node_line_dir, \"query_expansion\")\n if not os.path.exists(node_dir):\n os.makedirs(node_dir)\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n\n # run query expansion\n results, execution_times = zip(*map(lambda task: measure_speed(\n task[0], project_dir=project_dir, previous_result=previous_result, **task[1]), zip(modules, module_params)))\n average_times = list(map(lambda x: x / len(results[0]), execution_times))\n\n # save results to folder\n pseudo_module_params = deepcopy(module_params)\n for i, module_param in enumerate(pseudo_module_params):\n if 'prompt' in module_params:\n module_param['prompt'] = str(i)\n filepaths = list(map(lambda x: os.path.join(node_dir, f'{x}.parquet'), range(len(modules))))\n 
list(map(lambda x: x[0].to_parquet(x[1], index=False), zip(results, filepaths))) # execute save to parquet\n filenames = list(map(lambda x: os.path.basename(x), filepaths))\n\n # make summary file\n summary_df = pd.DataFrame({\n 'filename': filenames,\n 'module_name': list(map(lambda module: module.__name__, modules)),\n 'module_params': module_params,\n 'execution_time': average_times,\n })\n\n # Run evaluation when there are more than one module.\n if len(modules) > 1:\n # pop general keys from strategies (e.g. metrics, speed_threshold)\n general_key = ['metrics', 'speed_threshold']\n general_strategy = dict(filter(lambda x: x[0] in general_key, strategies.items()))\n extra_strategy = dict(filter(lambda x: x[0] not in general_key, strategies.items()))\n\n # first, filter by threshold if it is enabled.\n if general_strategy.get('speed_threshold') is not None:\n results, filenames = filter_by_threshold(results, average_times, general_strategy['speed_threshold'],\n filenames)\n\n # check metrics in strategy\n if general_strategy.get('metrics') is None:\n raise ValueError(\"You must at least one metrics for query expansion evaluation.\")\n\n if extra_strategy.get('top_k') is None:\n extra_strategy['top_k'] = 10 # default value\n\n # get retrieval modules from strategy\n retrieval_callables, retrieval_params = make_retrieval_callable_params(extra_strategy)\n\n # get retrieval_gt\n retrieval_gt = pd.read_parquet(os.path.join(project_dir, \"data\", \"qa.parquet\"))['retrieval_gt'].tolist()\n\n # run evaluation\n evaluation_results = list(map(lambda result: evaluate_one_query_expansion_node(\n retrieval_callables, retrieval_params, result['queries'].tolist(), retrieval_gt,\n general_strategy['metrics'], project_dir, previous_result), results))\n\n evaluation_df = pd.DataFrame({\n 'filename': filenames,\n **{f'query_expansion_{metric_name}': list(map(lambda x: x[metric_name].mean(), evaluation_results))\n for metric_name in general_strategy['metrics']}\n })\n summary_df = 
pd.merge(on='filename', left=summary_df, right=evaluation_df, how='left')\n\n best_result, best_filename = select_best_average(evaluation_results, general_strategy['metrics'], filenames)\n # change metric name columns to query_expansion_metric_name\n best_result = best_result.rename(columns={\n metric_name: f'query_expansion_{metric_name}' for metric_name in strategies['metrics']})\n best_result = best_result.drop(columns=['retrieved_contents', 'retrieved_ids', 'retrieve_scores'])\n else:\n best_result, best_filename = results[0], filenames[0]\n best_result = pd.concat([previous_result, best_result], axis=1)\n\n # add 'is_best' column at summary file\n summary_df['is_best'] = summary_df['filename'] == best_filename\n\n # save files\n summary_df.to_csv(os.path.join(node_dir, \"summary.csv\"), index=False)\n best_result.to_parquet(os.path.join(node_dir, f\"best_{os.path.splitext(best_filename)[0]}.parquet\"), index=False)\n\n return best_result\n```\n###test function signature:\n```python\n\ndef test_run_query_expansion_node_default(node_line_dir):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_run_query_expansion_node_default", + "reference": "\ndef test_run_query_expansion_node_default(node_line_dir):\n generator_models['mock'] = MockLLM\n project_dir = pathlib.PurePath(node_line_dir).parent.parent\n qa_path = os.path.join(project_dir, \"data\", \"qa.parquet\")\n previous_result = pd.read_parquet(qa_path)\n\n modules = [query_decompose, hyde]\n module_params = [{'llm': \"mock\", 'temperature': 0.2, 'batch': 7}, {'llm': \"mock\"}]\n strategies = {\n 'metrics': metrics\n }\n best_result = run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n base_query_expansion_test(best_result, node_line_dir)", + "completions": [ + "```python\ndef test_run_query_expansion_node_default(node_line_dir):\n modules = [query_expansion_module]\n module_params = [{'prompt': 'default'}]\n previous_result = 
pd.DataFrame({'query': ['a', 'b', 'c']})\n strategies = {'metrics': ['recall', 'precision', 'f1'], 'speed_threshold': 1000}\n run_query_expansion_node(modules, module_params, previous_result, node_line_dir, strategies)\n```\n" + ], + "line": 84, + "token": 850, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 36, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def run(self) -> Tuple[pd.DataFrame, FreqDetectionResult]:\n try:\n if self._prophet_detector_params[\"suppress_stan\"]:\n pd.set_option(\"mode.chained_assignment\", None)\n\n # Skip measurements based on feedbacks given from SODA Cloud\n preprocessed_df = self.preprocess(time_series_df=self.raw_time_series_df)\n\n # Automatically detect frequency of the time series\n freq_detector = FrequencyDetector(\n logs=self.logs,\n params=self.params,\n time_series_df=preprocessed_df,\n manual_freq=self.training_dataset_params.frequency,\n )\n freq_detection_result = freq_detector.detect_frequency()\n\n # Return if frequency detection failed\n if freq_detection_result.error_code_int >= ERROR_CODE_LEVEL_CUTTOFF:\n return self.exit_with_warning(freq_detection_result)\n\n # Apply training dataset configurations\n training_df = self.apply_training_dataset_configs(\n time_series_df=preprocessed_df, freq_detection_result=freq_detection_result\n )\n\n # Remove big gaps from the time series to not confuse Prophet\n training_df = self.remove_big_gaps_from_time_series(\n time_series_df=training_df, freq_detection_result=freq_detection_result\n )\n\n # Only use the last n points for training based on the window length\n window_length = self.get_window_length(training_df=training_df)\n training_df = training_df.iloc[-window_length:]\n\n training_df_shape = training_df[\"y\"].dropna().shape[0]\n if training_df_shape <= 
self._min_n_points:\n freq_detection_result = get_not_enough_measurements_freq_result(n_data_points=training_df_shape)\n return self.exit_with_warning(freq_detection_result)\n\n model_hyperparameters = self.get_prophet_hyperparameters(time_series_df=training_df)\n\n predictions_df = self.setup_fit_predict(\n time_series_df=training_df, model_hyperparameters=model_hyperparameters\n )\n anomalies_df = self.detect_anomalies(time_series_df=training_df, predictions_df=predictions_df)\n anomalies_df = self.generate_severity_zones(anomalies_df=anomalies_df)\n anomalies_df = self.compute_alert_level(anomalies_df=anomalies_df)\n return anomalies_df, freq_detection_result\n except Exception as e:\n raise e\n```\n###test function signature:\n```python\n\ndef test_with_exit() -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_with_exit", + "reference": "\ndef test_with_exit() -> None:\n time_series_df = generate_random_dataframe(size=2, n_rows_to_convert_none=0, frequency=\"D\")\n detector = ProphetDetector(\n logs=LOGS,\n params=PARAMS,\n time_series_df=time_series_df,\n model_cfg=ModelConfigs(),\n training_dataset_params=TrainingDatasetParameters(),\n severity_level_params=SeverityLevelParameters(),\n )\n df_anomalies, frequency_result = detector.run()\n assert df_anomalies.empty\n assert frequency_result.error_code_int == 100\n assert frequency_result.freq_detection_strategy == \"not_enough_measurements_custom\"", + "completions": [ + "```python\ndef test_with_exit() -> None:\n from soda.sodacl.checks.check_frequency import FrequencyDetector\n from soda.sodacl.checks.check_frequency import get_not_enough_measurements_freq_result\n from soda.sodacl.checks.check_frequency import ERROR_CODE_LEVEL_CUTTOFF\n from soda.sodacl.checks.check_frequency import FreqDetectionResult\n from soda.sodacl.checks.check_frequency import CheckFrequency\n from soda.sodacl.checks.check_frequency import CheckFrequencyParams\n from 
soda.sodacl.checks.check_frequency import CheckFrequencyLogs\n from soda.sodacl.checks.check_frequency import CheckFrequencyTrainingDatasetParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import 
CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyProphetDetectorParams\n from soda.sodacl.checks.check_frequency import CheckFrequencyPro" + ], + "line": 52, + "token": 310, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 37, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef average_transforms(\n R: torch.Tensor,\n t: torch.Tensor,\n w: torch.Tensor,\n mask: torch.Tensor,\n dim: int,\n t_edge: Optional[torch.Tensor] = None,\n dither: Optional[bool] = True,\n dither_eps: float = 1e-4,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert dim >= 0, \"dimension must index from the left\"\n w = torch.where(\n mask[..., None].bool(), w, torch.full_like(w, torch.finfo(w.dtype).min)\n )\n\n # We use different averaging models based on the number of weights\n num_transform_weights = w.size(-1)\n if num_transform_weights == 1:\n # Share a single scalar weight between t and R.\n probs = w.softmax(dim)\n t_probs = probs\n R_probs = probs[..., None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 2:\n # Use separate scalar weights for each of t and R.\n probs = w.softmax(dim)\n t_probs, R_probs = probs.unbind(-1)\n t_probs = t_probs[..., None]\n R_probs = R_probs[..., None, None]\n\n # Average translation.\n t_avg = (t * t_probs).sum(dim)\n elif num_transform_weights == 3:\n # For R use a signed scalar weight.\n R_probs = w[..., 
2].softmax(dim)[..., None, None]\n\n # For t use a two-parameter precision matrix P = P_isometric + P_radial.\n # We need to hand compute softmax over the shared dim x 2 elements.\n w_t = w[..., :2]\n w_t_total = w_t.logsumexp([dim, -1], True)\n p_iso, p_rad = (w_t - w_t_total).exp().unbind(-1)\n\n # Use Gaussian fusion for translation.\n t_edge = t_edge * mask.to(t_edge.dtype)[..., None]\n t_avg, _ = fuse_gaussians_isometric_plus_radial(t, p_iso, p_rad, t_edge, dim)\n else:\n raise NotImplementedError\n\n # Average rotation via SVD\n R_avg_unc = (R * R_probs).sum(dim)\n R_avg_unc = R_avg_unc + dither_eps * torch.randn_like(R_avg_unc)\n U, S, Vh = torch.linalg.svd(R_avg_unc, full_matrices=True)\n R_avg = U @ Vh\n\n # Enforce that matrix is rotation matrix\n d = torch.linalg.det(R_avg)\n d_expand = F.pad(d[..., None, None], (2, 0), value=1.0)\n Vh = Vh * d_expand\n R_avg = U @ Vh\n return R_avg, t_avg\n```\n###test function signature:\n```python\n\ndef test_average_transforms(rotations, translations):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_average_transforms", + "reference": "\ndef test_average_transforms(rotations, translations):\n R = torch.stack([rotations[0], torch.eye(3)])\n t = torch.stack([translations[0], torch.zeros(3)])\n w = torch.ones(2, 2)\n mask = torch.ones(2)\n # average of a transform with the identity is \"half\" the transform\n R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)\n R_total_fromavg, _ = compose_transforms(\n R_avg, torch.zeros(3), R_avg, torch.zeros(3)\n )\n _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)\n assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)\n assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)", + "completions": [ + "```python\ndef test_average_transforms(rotations, translations):\n # test case 1\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0])\n R, t = average_transforms(rotations, 
translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 2\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 3\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 4\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 5\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 6\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 7\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 8\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w, mask, 0)\n assert torch.allclose(R, rotations)\n assert torch.allclose(t, translations)\n\n # test case 9\n mask = torch.ones(1, 1, 1, 1)\n w = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])\n R, t = average_transforms(rotations, translations, w" + ], + "line": 65, + "token": 537, + 
"line_diff": 5, + "token_diff": 5 + }, + { + "id": 38, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef btlm_config_to_gpt2_config(btlm_config: PretrainedConfig) -> GPT2Config:\n return GPT2Config(\n vocab_size=btlm_config.vocab_size,\n n_positions=0 if btlm_config.position_embedding_type == \"alibi\" else btlm_config.n_positions,\n n_embd=btlm_config.hidden_size,\n n_layer=btlm_config.num_hidden_layers,\n n_head=btlm_config.num_attention_heads,\n n_inner=btlm_config.n_inner,\n activation_function=btlm_config.activation_function,\n resid_pdrop=btlm_config.resid_pdrop,\n embd_pdrop=btlm_config.embd_pdrop,\n attn_pdrop=btlm_config.attn_pdrop,\n layer_norm_epsilon=btlm_config.layer_norm_epsilon,\n initializer_range=btlm_config.initializer_range,\n bos_token_id=btlm_config.bos_token_id,\n eos_token_id=btlm_config.eos_token_id,\n # These are new arguments not in the original GPT2Config\n use_alibi=btlm_config.position_embedding_type == \"alibi\",\n use_flash_attn=btlm_config.position_embedding_type == \"alibi\", # Alibi code path requires flash_attn\n mup_width_scale=btlm_config.mup_width_scale,\n mup_embeddings_multiplier=btlm_config.mup_embeddings_scale,\n mup_output_multiplier=btlm_config.mup_output_alpha,\n mup_scale_qk_dot_by_d=btlm_config.mup_scale_qk_dot_by_d,\n mlp_multiple_of=1,\n )\n```\n###test function signature:\n```python\ndef test_btlm_optimized(model_name):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_btlm_optimized", + "reference": "def test_btlm_optimized(model_name):\n dtype = torch.float16\n device = \"cuda\"\n config = btlm_config_to_gpt2_config(\n AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n )\n config.fused_bias_fc = True\n config.fused_dropout_add_ln = True\n 
config.residual_in_fp32 = True\n\n pretrained_state_dict = remap_state_dict_hf_btlm(state_dict_from_pretrained(model_name), config)\n model = GPTLMHeadModel(config, device=device, dtype=dtype)\n model.load_state_dict(pretrained_state_dict)\n model.eval()\n\n torch.manual_seed(0)\n batch_size = 2\n max_seqlen = 256\n seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)\n input_ids = torch.randint(\n 0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long, device=device\n )\n with torch.no_grad():\n out = model.transformer(input_ids)\n logits = model(input_ids).logits\n del model\n\n # Without device_map, the model is loaded on the CPU, which is very slow\n # Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB\n model_ref = AutoModelForCausalLM.from_pretrained(\n model_name, device_map=\"auto\", trust_remote_code=True\n )\n model_ref.eval()\n with torch.no_grad():\n out_ref = model_ref.transformer(input_ids).last_hidden_state.to(device=device)\n logits_ref = model_ref(input_ids).logits.to(device=device)\n del model_ref\n\n model_hf = AutoModelForCausalLM.from_pretrained(\n model_name,\n torch_dtype=dtype,\n device_map={\"\": device},\n trust_remote_code=True,\n )\n model_hf.eval()\n with torch.no_grad():\n out_hf = model_hf.transformer(input_ids).last_hidden_state\n logits_hf = model_hf(input_ids).logits\n del model_hf\n\n print(f\"Output max diff: {(out - out_ref).abs().max().item()}\")\n print(f\"Output mean diff: {(out - out_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}\")\n assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()\n\n print(f\"Logits max diff: {(logits - logits_ref).abs().max().item()}\")\n print(f\"Logits mean diff: {(logits - logits_ref).abs().mean().item()}\")\n print(f\"HF fp16 max diff: {(logits_hf - 
logits_ref).abs().max().item()}\")\n print(f\"HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}\")\n assert (logits - logits_ref).abs().max().item() < 3 * (\n logits_hf - logits_ref\n ).abs().max().item()", + "completions": [ + "```python\ndef test_btlm_optimized(model_name):\n from transformers import GPT2Config, PretrainedConfig\n from bert_extended.optimized import btlm_config_to_gpt2_config\n\n config = PretrainedConfig.from_pretrained(model_name)\n gpt2_config = btlm_config_to_gpt2_config(config)\n gpt2_config.to_json_file(f\"{model_name}_optimized.json\")\n```\n" + ], + "line": 35, + "token": 169, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 39, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef grad_solid_harmonic(r: jnp.ndarray, l_max: int) -> jnp.ndarray:\n def cartesian_sum(r, ell):\n # computes all cartesian sums of order ell\n # i.e. 
all terms (a*x + b*y + c*z) for a+b+c = ell\n if ell == 0:\n return np.array([0.])\n return np.sum(np.stack(\n np.meshgrid(*[r]*ell, indexing='ij'), axis=0).reshape(ell, -1), axis=0)\n\n dharmonics = []\n for ell in range(l_max + 1):\n dx = cartesian_sum(np.array([1., 0., 0.]), ell)\n dy = cartesian_sum(np.array([0., 1., 0.]), ell)\n dz = cartesian_sum(np.array([0., 0., 1.]), ell)\n d = np.stack([dx, dy, dz], axis=1)\n cp = jax.vmap(functools.partial(cartesian_product, ell=ell), out_axes=-1)\n dcp = jnp.where(r.T[None, :, :] == 0, 0,\n d[:, :, None]*cp(r)[:, None, :] / r.T[None, :, :])\n sh = jnp.einsum('cs,cxn->snx', full_cart2sph(ell, True), dcp)\n pad = l_max - ell\n sh = jnp.pad(sh, [[pad, pad], [0, 0], [0, 0]])\n dharmonics.append(sh)\n\n return jnp.stack(dharmonics, axis=1) # [L, M, grid, xyz]\n```\n###test function signature:\n```python\n\n def test_grad_solid_harmonic(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_grad_solid_harmonic", + "reference": "\n def test_grad_solid_harmonic(self):\n np.random.seed(0)\n r = np.random.randn(100, 3)\n l_max = 5\n\n jax_grad = jax.jacfwd(lambda x: gto.solid_harmonic_from_cart(x, l_max))\n expected = jnp.transpose(jnp.squeeze(jax.vmap(jax_grad)(r[:, None, :])),\n [1, 2, 0, 3])\n\n with self.subTest('by hand'):\n observed = gto.grad_solid_harmonic(r, l_max)\n np.testing.assert_allclose(observed, expected, atol=1.e-4)\n with self.subTest('by jax'):\n observed_jacfwd = gto.grad_solid_harmonic_by_jacfwd(r, l_max)\n np.testing.assert_allclose(observed_jacfwd, expected, atol=1.e-4)", + "completions": [ + "```python\n def test_grad_solid_harmonic(self):\n r = jnp.array([[0.5, 0.5, 0.5]])\n l_max = 2\n dharmonics = grad_solid_harmonic(r, l_max)\n self.assertEqual(dharmonics.shape, (l_max + 1, l_max + 1, 1, 3))\n self.assertTrue(jnp.allclose(dharmonics[0, 0, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[1, 0, 0, :], jnp.array([1., 0., 0.])))\n 
self.assertTrue(jnp.allclose(dharmonics[1, 1, 0, :], jnp.array([0., 1., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 0, 0, :], jnp.array([0., 0., 1.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 1, 0, :], jnp.array([0., 0., 0.])))\n self.assertTrue(jnp.allclose(dharmonics[2, 2, 0, :], jnp.array([0., 0., 0.])))\n```\n" + ], + "line": 32, + "token": 389, + "line_diff": 2, + "token_diff": 4 + }, + { + "id": 40, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_basis(\n base_shape, angular_tesselation, remove_symmetries=True, eps=1e-4\n):\n\n if base_shape == 'tetrahedron':\n verts = np.array([\n (np.sqrt(8 / 9), 0, -1 / 3),\n (-np.sqrt(2 / 9), np.sqrt(2 / 3), -1 / 3),\n (-np.sqrt(2 / 9), -np.sqrt(2 / 3), -1 / 3),\n (0, 0, 1),\n ])\n faces = np.array([(0, 1, 2), (0, 2, 3), (0, 1, 3), (1, 2, 3)])\n elif base_shape == 'icosahedron':\n a = (np.sqrt(5) + 1) / 2\n verts = np.array([\n (-1, 0, a),\n (1, 0, a),\n (-1, 0, -a),\n (1, 0, -a),\n (0, a, 1),\n (0, a, -1),\n (0, -a, 1),\n (0, -a, -1),\n (a, 1, 0),\n (-a, 1, 0),\n (a, -1, 0),\n (-a, -1, 0),\n ]) / np.sqrt(a + 2)\n faces = np.array([\n (0, 4, 1),\n (0, 9, 4),\n (9, 5, 4),\n (4, 5, 8),\n (4, 8, 1),\n (8, 10, 1),\n (8, 3, 10),\n (5, 3, 8),\n (5, 2, 3),\n (2, 7, 3),\n (7, 10, 3),\n (7, 6, 10),\n (7, 11, 6),\n (11, 0, 6),\n (0, 1, 6),\n (6, 1, 10),\n (9, 0, 11),\n (9, 11, 2),\n (9, 2, 5),\n (7, 2, 11),\n ])\n elif base_shape == 'octahedron':\n verts = np.array(\n [(0, 0, -1), (0, 0, 1), (0, -1, 0), (0, 1, 0), (-1, 0, 0), (1, 0, 0)]\n )\n corners = np.array(list(itertools.product([-1, 1], repeat=3)))\n pairs = np.argwhere(compute_sq_dist(corners.T, verts.T) == 2)\n faces = np.sort(np.reshape(pairs[:, 1], [3, -1]).T, 1)\n else:\n raise ValueError(f'base_shape {base_shape} not supported')\n verts 
= tesselate_geodesic(verts, faces, angular_tesselation)\n\n if remove_symmetries:\n # Remove elements of `verts` that are reflections of each other.\n match = compute_sq_dist(verts.T, -verts.T) < eps\n verts = verts[~np.any(np.triu(match), axis=0), :]\n\n basis = verts[:, ::-1]\n return basis\n```\n###test function signature:\n```python\n def test_generate_basis_golden(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_basis_golden", + "reference": " def test_generate_basis_golden(self):\n\n basis = geopoly.generate_basis('tetrahedron', 1)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('tetrahedron', 2)\n basis_golden = np.array([\n [-0.33333333, -0.81649658, -0.47140452],\n [-0.57735027, 0.00000000, -0.81649658],\n [-0.33333333, 0.81649658, -0.47140452],\n [-0.57735027, -0.70710678, 0.40824829],\n [-0.57735027, 0.70710678, 0.40824829],\n [-0.33333333, 0.00000000, 0.94280904],\n [1.00000000, 0.00000000, 0.00000000],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('icosahedron', 2)\n basis_golden = np.array([\n [0.85065081, 0.00000000, 0.52573111],\n [0.80901699, 0.50000000, 0.30901699],\n [0.52573111, 0.85065081, 0.00000000],\n [1.00000000, 0.00000000, 0.00000000],\n [0.80901699, 0.50000000, -0.30901699],\n [0.85065081, 0.00000000, -0.52573111],\n [0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, -0.85065081],\n [0.50000000, 0.30901699, -0.80901699],\n [0.00000000, 1.00000000, 0.00000000],\n [-0.52573111, 0.85065081, 0.00000000],\n [-0.30901699, 0.80901699, -0.50000000],\n [0.00000000, 0.52573111, 0.85065081],\n [-0.30901699, 0.80901699, 0.50000000],\n [0.30901699, 0.80901699, 0.50000000],\n [0.50000000, 0.30901699, 
0.80901699],\n [0.50000000, -0.30901699, 0.80901699],\n [0.00000000, 0.00000000, 1.00000000],\n [-0.50000000, 0.30901699, 0.80901699],\n [-0.80901699, 0.50000000, 0.30901699],\n [-0.80901699, 0.50000000, -0.30901699],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))\n\n basis = geopoly.generate_basis('octahedron', 4)\n basis_golden = np.array([\n [0.00000000, 0.00000000, -1.00000000],\n [0.00000000, -0.31622777, -0.94868330],\n [0.00000000, -0.70710678, -0.70710678],\n [0.00000000, -0.94868330, -0.31622777],\n [0.00000000, -1.00000000, 0.00000000],\n [-0.31622777, 0.00000000, -0.94868330],\n [-0.40824829, -0.40824829, -0.81649658],\n [-0.40824829, -0.81649658, -0.40824829],\n [-0.31622777, -0.94868330, 0.00000000],\n [-0.70710678, 0.00000000, -0.70710678],\n [-0.81649658, -0.40824829, -0.40824829],\n [-0.70710678, -0.70710678, 0.00000000],\n [-0.94868330, 0.00000000, -0.31622777],\n [-0.94868330, -0.31622777, 0.00000000],\n [-1.00000000, 0.00000000, 0.00000000],\n [0.00000000, -0.31622777, 0.94868330],\n [0.00000000, -0.70710678, 0.70710678],\n [0.00000000, -0.94868330, 0.31622777],\n [0.40824829, -0.40824829, 0.81649658],\n [0.40824829, -0.81649658, 0.40824829],\n [0.31622777, -0.94868330, 0.00000000],\n [0.81649658, -0.40824829, 0.40824829],\n [0.70710678, -0.70710678, 0.00000000],\n [0.94868330, -0.31622777, 0.00000000],\n [0.31622777, 0.00000000, -0.94868330],\n [0.40824829, 0.40824829, -0.81649658],\n [0.40824829, 0.81649658, -0.40824829],\n [0.70710678, 0.00000000, -0.70710678],\n [0.81649658, 0.40824829, -0.40824829],\n [0.94868330, 0.00000000, -0.31622777],\n [0.40824829, -0.40824829, -0.81649658],\n [0.40824829, -0.81649658, -0.40824829],\n [0.81649658, -0.40824829, -0.40824829],\n ])\n self.assertTrue(is_same_basis(basis.T, basis_golden.T))", + "completions": [ + "```python\n def test_generate_basis_golden(self):\n basis = generate_basis(\n base_shape='tetrahedron', angular_tesselation=10, remove_symmetries=True, eps=1e-4\n )\n 
self.assertTrue(np.allclose(basis, np.array([[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0]])))\n```\n" + ], + "line": 75, + "token": 692, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 41, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def t_map(self, t: Union[float, torch.Tensor]) -> torch.Tensor:\n if not isinstance(t, torch.Tensor):\n t = torch.Tensor([t]).float()\n\n t_max = self.t_max.to(t.device)\n t_min = self.t_min.to(t.device)\n t_tilde = t_min + (t_max - t_min) * t\n\n return t_tilde\n```\n###test function signature:\n```python\n\ndef test_noise_schedule_diffusion_coeff(kind):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_noise_schedule_diffusion_coeff", + "reference": "\ndef test_noise_schedule_diffusion_coeff(kind):\n noise_schedule = GaussianNoiseSchedule(kind=kind, log_snr_range=(-6, 6))\n ts = torch.linspace(1e-2, 1 - 1e-2, 10)\n t_map = noise_schedule.t_map(ts) # map time to the prescribed log_SNR range\n\n if kind == \"log_snr\":\n g = noise_schedule.g(ts)\n\n # compute true beta_t\n l_range = noise_schedule.log_snr_range\n snr = torch.exp((1 - t_map) * l_range[1] + t_map * l_range[0])\n g_true = (-(l_range[0] - l_range[1]) / (snr + 1)).sqrt()\n\n assert torch.allclose(g, g_true, atol=1e-4)\n\n if kind == \"ot_linear\":\n g = noise_schedule.g(ts)\n g_true = (2.0 * t_map / (1.0 - t_map)).sqrt()\n assert torch.allclose(g, g_true, atol=1e-4)", + "completions": [ + "```python\ndef test_noise_schedule_diffusion_coeff(kind):\n from torch.testing import assert_allclose\n from torch.nn import functional as F\n from torch import tensor\n from torch.distributions import Uniform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import Transform\n from 
torch.distributions.transforms import SigmoidTransform\n from torch.distributions.transforms import TanhTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import ComposeTransform\n from torch.distributions.transforms import PowerTransform\n from torch.distributions.transforms import ExpTransform\n from torch.distributions.transforms import LogTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import 
AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import AffineTransform\n from torch.distributions.transforms import" + ], + "line": 17, + "token": 143, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 42, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_data_quality", + "reference": "\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index, temperature_data, temperature_mean=False, data_quality=True\n )\n assert df.shape == (27, 4)\n assert list(sorted(df.columns)) == [\n \"n_days_dropped\",\n \"n_days_kept\",\n \"temperature_not_null\",\n \"temperature_null\",\n ]\n assert round(df.temperature_not_null.mean(), 2) == 729.23\n assert round(df.temperature_null.mean(), 2) == 0.0", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_data_quality(\n il_electricity_cdd_hdd_billing_monthly,\n):\n meter_data_index = pd.date_range(\n 
start=\"2015-01-01\", end=\"2015-01-31\", freq=\"1H\", tz=\"America/Chicago\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0,\n 10.0" + ], + "line": 168, + "token": 984, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 43, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def unload_model(self, model_id: str) -> RegisteredModels:\n self.__ensure_v1_client_mode()\n de_aliased_model_id = resolve_roboflow_model_alias(model_id=model_id)\n response = requests.post(\n f\"{self.__api_url}/model/remove\",\n json={\n \"model_id\": de_aliased_model_id,\n },\n headers=DEFAULT_HEADERS,\n )\n response.raise_for_status()\n response_payload = response.json()\n if (\n de_aliased_model_id == self.__selected_model\n or model_id == self.__selected_model\n ):\n self.__selected_model = None\n return RegisteredModels.from_dict(response_payload)\n```\n###test function signature:\n```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_client_unload_single_model_when_successful_response_expected", + "reference": "\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n api_url = \"http://some.com\"\n requests_mock.post(\n f\"{api_url}/model/remove\",\n json={\"models\": [{\"model_id\": \"some/1\", \"task_type\": \"classification\"}]},\n )\n http_client = InferenceHTTPClient(api_key=\"my-api-key\", api_url=api_url)\n\n # when\n result = http_client.unload_model(model_id=\"some/1\")\n\n # then\n assert result == RegisteredModels(\n models=[ModelDescription(model_id=\"some/1\", task_type=\"classification\")]\n )\n assert requests_mock.last_request.json() == {\n \"model_id\": \"some/1\",\n }", + "completions": [ + "```python\n\ndef test_client_unload_single_model_when_successful_response_expected(\n requests_mock: Mocker,\n) -> None:\n # given\n model_id = \"my_model_id\"\n de_aliased_model_id = \"my_de_aliased_model_id\"\n 
response_payload = {\n \"model_id\": de_aliased_model_id,\n \"model_name\": \"my_model_name\",\n \"model_version\": \"my_model_version\",\n \"model_type\": \"my_model_type\",\n \"model_status\": \"my_model_status\",\n \"model_description\": \"my_model_description\",\n \"model_tags\": [\"my_model_tag1\", \"my_model_tag2\"],\n \"model_created_at\": \"my_model_created_at\",\n \"model_updated_at\": \"my_model_updated_at\",\n }\n requests_mock.post(\n f\"{self.__api_url}/model/remove\",\n json=response_payload,\n headers=DEFAULT_HEADERS,\n )\n # when\n response = self.client.unload_model(model_id=model_id)\n # then\n assert response.model_id == de_aliased_model_id\n assert response.model_name == \"my_model_name\"\n assert response.model_version == \"my_model_version\"\n assert response.model_type == \"my_model_type\"\n assert response.model_status == \"my_model_status\"\n assert response.model_description == \"my_model_description\"\n assert response.model_tags == [\"my_model_tag1\", \"my_model_tag2\"]\n assert response.model_created_at == \"my_model_created_at\"\n assert response.model_updated_at == \"my_model_updated_at\"\n assert self.client.selected_model is None\n```\n" + ], + "line": 31, + "token": 162, + "line_diff": 2, + "token_diff": 1 + }, + { + "id": 44, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=20.0,\n max_iteration=100,\n distance_strict=False,\n seed=None,\n):\n rng = np.random.default_rng(seed=seed)\n units_locations = np.zeros((num_units, 3), dtype=\"float32\")\n\n minimum_x, maximum_x = np.min(channel_locations[:, 0]) - margin_um, np.max(channel_locations[:, 0]) + margin_um\n minimum_y, maximum_y = np.min(channel_locations[:, 1]) - margin_um, np.max(channel_locations[:, 1]) + margin_um\n\n units_locations[:, 0] = rng.uniform(minimum_x, maximum_x, size=num_units)\n units_locations[:, 1] = rng.uniform(minimum_y, maximum_y, size=num_units)\n units_locations[:, 2] = rng.uniform(minimum_z, maximum_z, size=num_units)\n\n if minimum_distance is not None:\n solution_found = False\n renew_inds = None\n for i in range(max_iteration):\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n inds0, inds1 = np.nonzero(distances < minimum_distance)\n mask = inds0 != inds1\n inds0 = inds0[mask]\n inds1 = inds1[mask]\n\n if inds0.size > 0:\n if renew_inds is None:\n renew_inds = np.unique(inds0)\n else:\n # random only bad ones in the previous set\n renew_inds = renew_inds[np.isin(renew_inds, np.unique(inds0))]\n\n units_locations[:, 0][renew_inds] = rng.uniform(minimum_x, maximum_x, size=renew_inds.size)\n units_locations[:, 1][renew_inds] = rng.uniform(minimum_y, maximum_y, size=renew_inds.size)\n units_locations[:, 2][renew_inds] = rng.uniform(minimum_z, maximum_z, size=renew_inds.size)\n else:\n solution_found = True\n break\n\n if not solution_found:\n if distance_strict:\n raise ValueError(\n f\"generate_unit_locations(): no solution for {minimum_distance=} and 
{max_iteration=} \"\n \"You can use distance_strict=False or reduce minimum distance\"\n )\n else:\n warnings.warn(f\"generate_unit_locations(): no solution for {minimum_distance=} and {max_iteration=}\")\n\n return units_locations\n```\n###test function signature:\n```python\n\ndef test_generate_unit_locations():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_generate_unit_locations", + "reference": "\ndef test_generate_unit_locations():\n seed = 0\n\n probe = generate_multi_columns_probe(num_columns=2, num_contact_per_column=20, xpitch=20, ypitch=20)\n channel_locations = probe.contact_positions\n\n num_units = 100\n minimum_distance = 20.0\n unit_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=20.0,\n minimum_z=5.0,\n maximum_z=40.0,\n minimum_distance=minimum_distance,\n max_iteration=500,\n distance_strict=False,\n seed=seed,\n )\n distances = np.linalg.norm(unit_locations[:, np.newaxis] - unit_locations[np.newaxis, :], axis=2)\n dist_flat = np.triu(distances, k=1).flatten()\n dist_flat = dist_flat[dist_flat > 0]\n assert np.all(dist_flat > minimum_distance)\n\n # import matplotlib.pyplot as plt\n # fig, ax = plt.subplots()\n # ax.hist(dist_flat, bins = np.arange(0, 400, 10))\n # fig, ax = plt.subplots()\n # from probeinterface.plotting import plot_probe\n # plot_probe(probe, ax=ax)\n # ax.scatter(unit_locations[:, 0], unit_locations[:,1], marker='*', s=20)\n # plt.show()", + "completions": [ + "```python\n\ndef test_generate_unit_locations():\n num_units = 10\n channel_locations = np.array([[0, 0, 0], [10, 10, 10], [20, 20, 20]])\n margin_um = 20.0\n minimum_z = 5.0\n maximum_z = 40.0\n minimum_distance = 20.0\n max_iteration = 100\n distance_strict = False\n seed = None\n\n units_locations = generate_unit_locations(\n num_units,\n channel_locations,\n margin_um=margin_um,\n minimum_z=minimum_z,\n maximum_z=maximum_z,\n minimum_distance=minimum_distance,\n max_iteration=max_iteration,\n 
distance_strict=distance_strict,\n seed=seed,\n )\n\n assert units_locations.shape == (num_units, 3)\n assert np.all(units_locations[:, 0] >= margin_um)\n assert np.all(units_locations[:, 0] <= 100 - margin_um)\n assert np.all(units_locations[:, 1] >= margin_um)\n assert np.all(units_locations[:, 1] <= 100 - margin_um)\n assert np.all(units_locations[:, 2] >= minimum_z)\n assert np.all(units_locations[:, 2] <= maximum_z)\n\n # check distance\n distances = np.linalg.norm(units_locations[:, np.newaxis] - units_locations[np.newaxis, :], axis=2)\n assert np.all(distances >= minimum_distance)\n\n # check no duplicate\n assert np.all(np.unique(units_locations, axis=0) == units_locations)\n```\n" + ], + "line": 58, + "token": 453, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 45, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def to_dataframe(\n self, index: Optional[Union[str, None]] = \"date\", sort_by: Optional[str] = None\n ) -> pd.DataFrame:\n\n def is_list_of_basemodel(items: Union[List[T], T]) -> bool:\n return isinstance(items, list) and all(\n isinstance(item, BaseModel) for item in items\n )\n\n if self.results is None or not self.results:\n raise OpenBBError(\"Results not found.\")\n\n if isinstance(self.results, pd.DataFrame):\n return self.results\n\n try:\n res = self.results\n df = None\n sort_columns = True\n\n # List[Dict]\n if isinstance(res, list) and len(res) == 1 and isinstance(res[0], dict):\n r = res[0]\n dict_of_df = {}\n\n for k, v in r.items():\n # Dict[str, List[BaseModel]]\n if is_list_of_basemodel(v):\n dict_of_df[k] = basemodel_to_df(v, index)\n sort_columns = False\n # Dict[str, Any]\n else:\n dict_of_df[k] = pd.DataFrame(v)\n\n df = pd.concat(dict_of_df, axis=1)\n\n # List[BaseModel]\n elif is_list_of_basemodel(res):\n dt: 
Union[List[Data], Data] = res # type: ignore\n df = basemodel_to_df(dt, index)\n sort_columns = False\n # List[List | str | int | float] | Dict[str, Dict | List | BaseModel]\n else:\n try:\n df = pd.DataFrame(res)\n # Set index, if any\n if index is not None and index in df.columns:\n df.set_index(index, inplace=True)\n\n except ValueError:\n if isinstance(res, dict):\n df = pd.DataFrame([res])\n\n if df is None:\n raise OpenBBError(\"Unsupported data format.\")\n\n # Drop columns that are all NaN, but don't rearrange columns\n if sort_columns:\n df.sort_index(axis=1, inplace=True)\n df = df.dropna(axis=1, how=\"all\")\n\n # Sort by specified column\n if sort_by:\n df.sort_values(by=sort_by, inplace=True)\n\n except OpenBBError as e:\n raise e\n except ValueError as ve:\n raise OpenBBError(\n f\"ValueError: {ve}. Ensure the data format matches the expected format.\"\n ) from ve\n except TypeError as te:\n raise OpenBBError(\n f\"TypeError: {te}. Check the data types in your results.\"\n ) from te\n except Exception as ex:\n raise OpenBBError(f\"An unexpected error occurred: {ex}\") from ex\n\n return df\n```\n###test function signature:\n```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_to_dataframe_w_args", + "reference": "def test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n co = OBBject(results=results)\n\n # Act and Assert\n result = co.to_dataframe(index=index, sort_by=sort_by)\n assert isinstance(result, pd.DataFrame)\n assert result.index.name == index\n\n # check if dataframe is properly sorted\n assert result[sort_by].is_monotonic_increasing", + "completions": [ + "```python\ndef test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n from openbb_terminal.decorators import log_start_end\n from openbb_terminal.stocks.fundamental_analysis import financedatabase_model\n\n @log_start_end(log=logger)\n def 
test_to_dataframe_w_args(results, index, sort_by):\n # Arrange\n df = financedatabase_model.to_dataframe(results, index, sort_by)\n # Act\n # Assert\n assert df.shape[0] > 0\n assert df.shape[1] > 0\n assert df.index.name == index\n assert df.columns.tolist() == sort_by\n\n # Act\n test_to_dataframe_w_args(results, index, sort_by)\n # Assert\n```\n" + ], + "line": 76, + "token": 566, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 46, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef unique(ar, return_index=False, return_inverse=False,\n return_counts=False, axis=None, *, equal_nan=True):\n ar = np.asanyarray(ar)\n if axis is None:\n ret = _unique1d(ar, return_index, return_inverse, return_counts, \n equal_nan=equal_nan)\n return _unpack_tuple(ret)\n\n # axis was specified and not None\n try:\n ar = np.moveaxis(ar, axis, 0)\n except np.AxisError:\n # this removes the \"axis1\" or \"axis2\" prefix from the error message\n raise np.AxisError(axis, ar.ndim) from None\n\n # Must reshape to a contiguous 2D array for this to work...\n orig_shape, orig_dtype = ar.shape, ar.dtype\n ar = ar.reshape(orig_shape[0], np.prod(orig_shape[1:], dtype=np.intp))\n ar = np.ascontiguousarray(ar)\n dtype = [('f{i}'.format(i=i), ar.dtype) for i in range(ar.shape[1])]\n\n # At this point, `ar` has shape `(n, m)`, and `dtype` is a structured\n # data type with `m` fields where each field has the data type of `ar`.\n # In the following, we create the array `consolidated`, which has\n # shape `(n,)` with data type `dtype`.\n try:\n if ar.shape[1] > 0:\n consolidated = ar.view(dtype)\n else:\n # If ar.shape[1] == 0, then dtype will be `np.dtype([])`, which is\n # a data type with itemsize 0, and the call `ar.view(dtype)` will\n # fail. 
Instead, we'll use `np.empty` to explicitly create the\n # array with shape `(len(ar),)`. Because `dtype` in this case has\n # itemsize 0, the total size of the result is still 0 bytes.\n consolidated = np.empty(len(ar), dtype=dtype)\n except TypeError as e:\n # There's no good way to do this for object arrays, etc...\n msg = 'The axis argument to unique is not supported for dtype {dt}'\n raise TypeError(msg.format(dt=ar.dtype)) from e\n\n def reshape_uniq(uniq):\n n = len(uniq)\n uniq = uniq.view(orig_dtype)\n uniq = uniq.reshape(n, *orig_shape[1:])\n uniq = np.moveaxis(uniq, 0, axis)\n return uniq\n\n output = _unique1d(consolidated, return_index,\n return_inverse, return_counts, equal_nan=equal_nan)\n output = (reshape_uniq(output[0]),) + output[1:]\n return _unpack_tuple(output)\n```\n###test function signature:\n```python\n\n def test_unique_1d(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_unique_1d", + "reference": "\n def test_unique_1d(self):\n\n def check_all(a, b, i1, i2, c, dt):\n base_msg = 'check {0} failed for type {1}'\n\n msg = base_msg.format('values', dt)\n v = unique(a)\n assert_array_equal(v, b, msg)\n\n msg = base_msg.format('return_index', dt)\n v, j = unique(a, True, False, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i1, msg)\n\n msg = base_msg.format('return_inverse', dt)\n v, j = unique(a, False, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, i2, msg)\n\n msg = base_msg.format('return_counts', dt)\n v, j = unique(a, False, False, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j, c, msg)\n\n msg = base_msg.format('return_index and return_inverse', dt)\n v, j1, j2 = unique(a, True, True, False)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n\n msg = base_msg.format('return_index and return_counts', dt)\n v, j1, j2 = unique(a, True, False, True)\n assert_array_equal(v, b, msg)\n 
assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format('return_inverse and return_counts', dt)\n v, j1, j2 = unique(a, False, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i2, msg)\n assert_array_equal(j2, c, msg)\n\n msg = base_msg.format(('return_index, return_inverse '\n 'and return_counts'), dt)\n v, j1, j2, j3 = unique(a, True, True, True)\n assert_array_equal(v, b, msg)\n assert_array_equal(j1, i1, msg)\n assert_array_equal(j2, i2, msg)\n assert_array_equal(j3, c, msg)\n\n a = [5, 7, 1, 2, 1, 5, 7]*10\n b = [1, 2, 5, 7]\n i1 = [2, 3, 0, 1]\n i2 = [2, 3, 0, 1, 0, 2, 3]*10\n c = np.multiply([2, 1, 2, 2], 10)\n\n # test for numeric arrays\n types = []\n types.extend(np.typecodes['AllInteger'])\n types.extend(np.typecodes['AllFloat'])\n types.append('datetime64[D]')\n types.append('timedelta64[D]')\n for dt in types:\n aa = np.array(a, dt)\n bb = np.array(b, dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for object arrays\n dt = 'O'\n aa = np.empty(len(a), dt)\n aa[:] = a\n bb = np.empty(len(b), dt)\n bb[:] = b\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for structured arrays\n dt = [('', 'i'), ('', 'i')]\n aa = np.array(list(zip(a, a)), dt)\n bb = np.array(list(zip(b, b)), dt)\n check_all(aa, bb, i1, i2, c, dt)\n\n # test for ticket #2799\n aa = [1. + 0.j, 1 - 1.j, 1]\n assert_array_equal(np.unique(aa), [1. - 1.j, 1. + 0.j])\n\n # test for ticket #4785\n a = [(1, 2), (1, 2), (2, 3)]\n unq = [1, 2, 3]\n inv = [0, 1, 0, 1, 1, 2]\n a1 = unique(a)\n assert_array_equal(a1, unq)\n a2, a2_inv = unique(a, return_inverse=True)\n assert_array_equal(a2, unq)\n assert_array_equal(a2_inv, inv)\n\n # test for chararrays with return_inverse (gh-5099)\n a = np.chararray(5)\n a[...] 
= ''\n a2, a2_inv = np.unique(a, return_inverse=True)\n assert_array_equal(a2_inv, np.zeros(5))\n\n # test for ticket #9137\n a = []\n a1_idx = np.unique(a, return_index=True)[1]\n a2_inv = np.unique(a, return_inverse=True)[1]\n a3_idx, a3_inv = np.unique(a, return_index=True,\n return_inverse=True)[1:]\n assert_equal(a1_idx.dtype, np.intp)\n assert_equal(a2_inv.dtype, np.intp)\n assert_equal(a3_idx.dtype, np.intp)\n assert_equal(a3_inv.dtype, np.intp)\n\n # test for ticket 2111 - float\n a = [2.0, np.nan, 1.0, np.nan]\n ua = [1.0, 2.0, np.nan]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - complex\n a = [2.0-1j, np.nan, 1.0+1j, complex(0.0, np.nan), complex(1.0, np.nan)]\n ua = [1.0+1j, 2.0-1j, complex(0.0, np.nan)]\n ua_idx = [2, 0, 3]\n ua_inv = [1, 2, 0, 2, 2]\n ua_cnt = [1, 1, 3]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - datetime64\n nat = np.datetime64('nat')\n a = [np.datetime64('2020-12-26'), nat, np.datetime64('2020-12-24'), nat]\n ua = [np.datetime64('2020-12-24'), np.datetime64('2020-12-26'), nat]\n ua_idx = [2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for ticket 2111 - timedelta\n nat = np.timedelta64('nat')\n a = [np.timedelta64(1, 'D'), nat, np.timedelta64(1, 'h'), nat]\n ua = [np.timedelta64(1, 'h'), np.timedelta64(1, 'D'), nat]\n ua_idx = 
[2, 0, 1]\n ua_inv = [1, 2, 0, 2]\n ua_cnt = [1, 1, 2]\n assert_equal(np.unique(a), ua)\n assert_equal(np.unique(a, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(a, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(a, return_counts=True), (ua, ua_cnt))\n\n # test for gh-19300\n all_nans = [np.nan] * 4\n ua = [np.nan]\n ua_idx = [0]\n ua_inv = [0, 0, 0, 0]\n ua_cnt = [4]\n assert_equal(np.unique(all_nans), ua)\n assert_equal(np.unique(all_nans, return_index=True), (ua, ua_idx))\n assert_equal(np.unique(all_nans, return_inverse=True), (ua, ua_inv))\n assert_equal(np.unique(all_nans, return_counts=True), (ua, ua_cnt))", + "completions": [ + "```python\n def test_unique_1d(self):\n ar = np.array([1, 2, 3, 4, 5, 5, 5, 5, 5, 5])\n self.assertEqual(unique(ar), np.array([1, 2, 3, 4, 5]))\n self.assertEqual(unique(ar, return_index=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n self.assertEqual(unique(ar, return_index=True, return_inverse=True, return_counts=True), (np.array([1, 2, 3, 4, 5]), np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))\n```\n" + ], + "line": 56, + "token": 571, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 47, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef scaled_dot_product_attention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n att_mask: Optional[Union[AttentionMask, \"SparseCS\", torch.Tensor]],\n dropout: Optional[torch.nn.Module] = None,\n block_size: int = 128,\n) -> torch.Tensor:\n autocast_disabled = (\n _has_cpp_library\n and isinstance(att_mask, SparseCS)\n or (att_mask is not None and att_mask.is_sparse)\n )\n seq_len = q.shape[-2]\n\n # switch if:\n # causal is required but mask is not sparse\n # fp16 or under amp context\n # sequence length is divisible by block size\n # same seq len for K and Q\n switch_to_blocksparse = (\n _is_blocksparse_available\n and (att_mask is not None and not att_mask.is_sparse)\n and (isinstance(att_mask, AttentionMask) and att_mask.is_causal)\n and (q.dtype == torch.float16 or torch.is_autocast_enabled())\n and not seq_len % block_size\n and q.shape[-2] == k.shape[-2]\n )\n\n if switch_to_blocksparse:\n logger.info(\"Switching causal attention to Triton blocksparse...\")\n return blocksparse_attention(q, k, v, dropout, block_size)\n\n with torch.cuda.amp.autocast(enabled=False) if autocast_disabled else nullcontext(): # type: ignore\n if autocast_disabled:\n q, k, v = q.float(), k.float(), v.float()\n\n att = scaled_query_key_softmax(q, k, att_mask=att_mask)\n\n # Optional dropout, could be part of the masking in the future\n att = _apply_dropout(att, dropout)\n\n # Get to the predicted values, for all heads\n # y = att @ v # (N, S, S) x (N, S, hs) -> (N, S, hs)\n y = bmm(att, v)\n return y\n```\n###test function signature:\n```python\n\ndef test_amp_attention_dense(device):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_amp_attention_dense", + "reference": "\ndef test_amp_attention_dense(device):\n b, s, d = 8, 64, 32\n 
prob = 0.9\n\n a = torch.rand(b, s, d, device=device)\n m = torch.rand(s, s, device=device) > prob\n\n with torch.cuda.amp.autocast():\n r = scaled_dot_product_attention(a, a, a, m)\n\n expected_device = torch.float16 if device == \"cuda\" else torch.float32\n assert r.dtype == expected_device", + "completions": [ + "```python\n\ndef test_amp_attention_dense(device):\n q = torch.randn(1, 128, 128, device=device)\n k = torch.randn(1, 128, 128, device=device)\n v = torch.randn(1, 128, 128, device=device)\n att_mask = torch.randn(1, 128, 128, device=device)\n dropout = torch.nn.Dropout(0.1)\n block_size = 128\n y = scaled_dot_product_attention(q, k, v, att_mask, dropout, block_size)\n assert y.shape == (1, 128, 128)\n```\n" + ], + "line": 50, + "token": 390, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 48, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def quantize(\n cls,\n tensor: Tensor,\n nbits: int = 4,\n channel_wise: bool = True,\n group_size: int = 64,\n optimize: bool = False,\n round_zero: bool = False,\n axis: int = 0,\n bitpack: bool = True,\n compute_dtype: torch.dtype | None = None,\n view_as_float: bool = False,\n device: str = \"cuda\",\n ) -> tuple:\n assert nbits in Quantizer.SUPPORTED_BITS, (\n \"nbits=\" + str(nbits) + \" not supported.\"\n )\n assert axis in [0, 1], \"axis should be either 0 or 1\"\n if group_size is not None:\n assert is_divisible(tensor.numel(), group_size), (\n \"group_size should be divisble by the total tensor dimensions. 
shape: \"\n + str(tensor.shape)\n + \", group_size: \"\n + str(group_size)\n )\n\n W = tensor.float()\n shape = W.shape\n\n # Reshape for grouping\n if (group_size is not None) and channel_wise:\n W = (\n W.reshape([-1, group_size])\n if (axis == 1)\n else W.reshape([group_size, -1])\n )\n\n # Get min/max values\n if not channel_wise:\n _min, _max = W.min(), W.max()\n optimize = False\n else:\n _min = W.min(axis=axis, keepdim=True)[0]\n _max = W.max(axis=axis, keepdim=True)[0]\n\n max_v = 2**nbits - 1\n min_v = 0\n min_max = [min_v, max_v]\n\n # Note: here we work with the inverse of the scale to avoid division and quantize instead via W*scale + zero, the scale is inverted later on.\n scale = (max_v / (_max - _min)).clamp(\n max=2e4\n ) # clamp to avoid half-precision problems\n zero = -_min * scale\n\n # Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14\n if round_zero:\n zero = torch.round(zero)\n\n # Fine-tune weights\n if optimize:\n W_q, scale, zero = Quantizer.optimize_weights(\n tensor=W,\n scale=scale,\n zero=zero,\n min_max=min_max,\n axis=axis,\n device=device,\n )\n else:\n W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])\n\n # Store meta-data (we invert the scale for dequantization)\n meta = {\n \"nbits\": nbits,\n \"group_size\": group_size,\n \"shape\": shape,\n \"scale\": 1.0 / scale,\n \"zero\": zero,\n \"axis\": axis,\n \"packing\": Quantizer.bit_to_packing[nbits],\n }\n meta[\"unpack_view_dtype\"] = Quantizer.unpack_view_dtype[meta[\"packing\"]]\n\n # Pack bits\n meta[\"view_as_float\"] = view_as_float\n if bitpack:\n W_q = Quantizer.pack[meta[\"packing\"]](W_q)\n if view_as_float:\n W_q = W_q.view(\n torch.float32 if compute_dtype is None else compute_dtype\n ) # store quantized weights as compute_dtype\n else:\n W_q = W_q.to(tensor.dtype)\n meta[\"packing\"] = None\n\n # cleanup\n del W, _min, _max\n torch.cuda.empty_cache()\n\n return W_q, meta\n```\n###test function 
signature:\n```python\n\n def test_quantizer(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_quantizer", + "reference": "\n def test_quantizer(self): \n for compute_dtype in [torch.float32, torch.float16, torch.bfloat16]:\n for nbits in [8,4,3,2,1]:\n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, view_as_float=False)\n if nbits == 3:\n assert W_q.dtype == torch.int32\n else:\n assert W_q.dtype == torch.uint8\n w_dq = Quantizer.dequantize(W_q, meta)\n norm1 = torch.norm(self.w - w_dq, p=0.7)\n \n W_q, meta = Quantizer.quantize(self.w, nbits=nbits, round_zero=True, optimize=True, compute_dtype=compute_dtype, view_as_float=True)\n assert W_q.dtype == compute_dtype\n w_dq = Quantizer.dequantize(W_q, meta)\n norm2 = torch.norm(self.w - w_dq, p=0.7)\n \n self.assertTrue(torch.equal(norm1, norm2))", + "completions": [ + "```python\n\n def test_quantizer(self):\n # test case 1\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 2\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=False, round_zero=True, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n 
self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], 0.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 3\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor, nbits=4, channel_wise=True, group_size=64, optimize=True, round_zero=False, axis=0, bitpack=True, compute_dtype=None, view_as_float=False, device=\"cuda\")\n self.assertEqual(W_q.shape, (1, 3, 224, 224))\n self.assertEqual(meta[\"nbits\"], 4)\n self.assertEqual(meta[\"group_size\"], 64)\n self.assertEqual(meta[\"shape\"], (1, 3, 224, 224))\n self.assertEqual(meta[\"scale\"], 1.0 / 15.999999999999998)\n self.assertEqual(meta[\"zero\"], -1.0)\n self.assertEqual(meta[\"axis\"], 0)\n self.assertEqual(meta[\"packing\"], \"uint4\")\n self.assertEqual(meta[\"unpack_view_dtype\"], torch.float32)\n self.assertEqual(meta[\"view_as_float\"], False)\n\n # test case 4\n tensor = torch.randn(1, 3, 224, 224)\n W_q, meta = quantize(tensor," + ], + "line": 100, + "token": 645, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 49, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_num_ref_area", + "reference": "\ndef test_nanoplot_out_bars_with_num_ref_area():\n\n out_bars_with_num_ref_area = _generate_nanoplot(**CASES[13])\n\n assert _is_nanoplot_output(out_bars_with_num_ref_area)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_num_ref_area,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_num_ref_area():\n y_vals = [1, 2, 3, 4, 5]\n y_ref_area = [1, 3]\n y_ref_line = 2\n x_vals = [1, 2, 3, 4, 5]\n expand_x = [1, 2, 3, 4, 5]\n expand_y = [1, 2, 3, 4, 5]\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5]\n all_single_y_vals = [1, 2, 3, 4, 
5]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_stroke_width=data_line_stroke_width,\n data_area_fill_color=data_area_fill_color,\n data_bar_stroke_color=data_bar_stroke_color,\n data_bar_stroke_width=data_bar_stroke_width,\n data_bar_fill_color=data_bar_fill_" + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + 
"id": 50, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef mean(x, axis=None, keepdims=False):\n if isinstance(x, (list, tuple)):\n x = stack(x)\n x = convert_to_tensor(x)\n if axis == () or axis == []:\n # Torch handles the empty axis case differently from numpy.\n return x\n axis = to_tuple_or_list(axis) # see [NB] below\n\n ori_dtype = standardize_dtype(x.dtype)\n # torch.mean only supports floating point inputs\n compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n if \"int\" in ori_dtype or ori_dtype == \"bool\":\n result_dtype = compute_dtype\n else:\n result_dtype = ori_dtype\n\n # [NB] the python torch op torch.mean() is generated into\n # `torch._C._VariableFunctions.pyi`, and the method\n # signature is overloaded.\n # Dynamo won't actually find the correct signature of\n # `torch.mean()` if arguments are passed via kwargs\n # So we have to pass the arguments via positional args\n # EXCEPT for those that are forced as kwargs via the `*`\n # delimiter in the overloaded method signatures.\n # Additionally, we have to create a singleton-tuple\n # when `axis` is an int to match the existing fn signature\n result = torch.mean(\n x,\n axis,\n keepdims,\n dtype=to_torch_dtype(compute_dtype),\n )\n return cast(result, result_dtype)\n```\n###test function signature:\n```python\n\n def test_mean(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mean", + "reference": "\n def test_mean(self):\n x = np.array([[1, 2, 3], [3, 2, 1]])\n self.assertAllClose(knp.mean(x), np.mean(x))\n self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n self.assertAllClose(\n knp.mean(x, 
axis=1, keepdims=True),\n np.mean(x, axis=1, keepdims=True),\n )\n\n self.assertAllClose(knp.Mean()(x), np.mean(x))\n self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n self.assertAllClose(\n knp.Mean(axis=1, keepdims=True)(x),\n np.mean(x, axis=1, keepdims=True),\n )\n\n # test overflow\n x = np.array([65504, 65504, 65504], dtype=\"float16\")\n self.assertAllClose(knp.mean(x), np.mean(x))", + "completions": [ + "```python\n\n def test_mean(self):\n x = torch.randn(2, 3, 4)\n self.assertEqual(mean(x), torch.mean(x))\n self.assertEqual(mean(x, axis=0), torch.mean(x, dim=0))\n self.assertEqual(mean(x, axis=1), torch.mean(x, dim=1))\n self.assertEqual(mean(x, axis=2), torch.mean(x, dim=2))\n self.assertEqual(mean(x, axis=(0, 1)), torch.mean(x, dim=(0, 1)))\n self.assertEqual(mean(x, axis=(0, 2)), torch.mean(x, dim=(0, 2)))\n self.assertEqual(mean(x, axis=(1, 2)), torch.mean(x, dim=(1, 2)))\n self.assertEqual(mean(x, axis=(0, 1, 2)), torch.mean(x, dim=(0, 1, 2)))\n self.assertEqual(mean(x, axis=0, keepdims=True), torch.mean(x, dim=0, keepdim=True))\n self.assertEqual(mean(x, axis=1, keepdims=True), torch.mean(x, dim=1, keepdim=True))\n self.assertEqual(mean(x, axis=2, keepdims=True), torch.mean(x, dim=2, keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1), keepdims=True), torch.mean(x, dim=(0, 1), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 2), keepdims=True), torch.mean(x, dim=(0, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(1, 2), keepdims=True), torch.mean(x, dim=(1, 2), keepdim=True))\n self.assertEqual(mean(x, axis=(0, 1, 2), keepdims=True), torch.mean(x, dim=(0, 1, 2), keepdim=True))\n```\n" + ], + "line": 42, + "token": 322, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 51, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_intrinsic_value(\n cash_flow: float,\n growth_rate: float,\n perpetual_growth_rate: float,\n weighted_average_cost_of_capital: float,\n cash_and_cash_equivalents: float,\n total_debt: float,\n shares_outstanding: float,\n periods: int = 5,\n) -> pd.DataFrame:\n components = {}\n\n cash_flow_projection = [cash_flow]\n\n # Cash Flow to Use\n for period in range(1, periods + 1):\n if period == 1:\n cash_flow_projection.append(cash_flow_projection[0] * (1 + growth_rate))\n else:\n cash_flow_projection.append(\n cash_flow_projection[period - 1] * (1 + growth_rate)\n )\n\n # Calculate the Terminal Value\n terminal_value = (\n cash_flow_projection[-1]\n * (1 + perpetual_growth_rate)\n / (weighted_average_cost_of_capital - perpetual_growth_rate)\n )\n\n # Add Terminal Value to the end of the cash flow projection\n cash_flow_projection[-1] = cash_flow_projection[-1] + terminal_value\n\n # Calculate the Present Value based on the Discounted Cash Flow Formula\n cash_flow_present_value = []\n for index, cash_flow_value in enumerate(cash_flow_projection):\n cash_flow_present_value.append(\n cash_flow_value / (1 + weighted_average_cost_of_capital) ** (index + 1)\n )\n\n # Calculate the Enterprise Value\n enterprise_value = sum(cash_flow_present_value)\n\n # Calculate the Equity Value\n equity_value = enterprise_value + cash_and_cash_equivalents - total_debt\n\n # Calculate the Intrinsic Value\n intrinsic_value = equity_value / shares_outstanding\n\n # Combine the components into a dictionary\n components = {\n \"Terminal Value\": terminal_value,\n \"Cash Flow Projection\": cash_flow_projection[-1],\n \"Enterprise Value\": enterprise_value,\n \"Equity Value\": equity_value,\n \"Intrinsic Value\": intrinsic_value,\n }\n\n return pd.DataFrame.from_dict(\n components, 
orient=\"index\", columns=[f\"Periods = {periods}\"]\n )\n```\n###test function signature:\n```python\n\ndef test_get_intrinsic_value(recorder):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_intrinsic_value", + "reference": "\ndef test_get_intrinsic_value(recorder):\n recorder.capture(\n intrinsic_model.get_intrinsic_value(\n cash_flow=500,\n growth_rate=0.05,\n perpetual_growth_rate=0.02,\n weighted_average_cost_of_capital=0.1,\n cash_and_cash_equivalents=100,\n total_debt=100,\n shares_outstanding=100,\n periods=5,\n )\n )", + "completions": [ + "```python\n\ndef test_get_intrinsic_value(recorder):\n cash_flow = 1000000\n growth_rate = 0.05\n perpetual_growth_rate = 0.05\n weighted_average_cost_of_capital = 0.1\n cash_and_cash_equivalents = 1000000\n total_debt = 1000000\n shares_outstanding = 1000000\n periods = 5\n\n result = get_intrinsic_value(\n cash_flow,\n growth_rate,\n perpetual_growth_rate,\n weighted_average_cost_of_capital,\n cash_and_cash_equivalents,\n total_debt,\n shares_outstanding,\n periods,\n )\n\n recorder.capture(result)\n```\n" + ], + "line": 61, + "token": 379, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 52, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef to_numpy(\n evset: EventSet,\n timestamp_to_datetime: bool = True,\n timestamps: bool = True,\n) -> Dict[str, ndarray]:\n timestamp_key = \"timestamp\"\n index_names = evset.schema.index_names()\n feature_names = evset.schema.feature_names()\n\n column_names = index_names + feature_names\n if timestamps:\n column_names += [timestamp_key]\n\n dst = {column_name: [] for column_name in column_names}\n for index, data in evset.data.items():\n assert isinstance(index, tuple)\n\n if timestamps:\n # Timestamps\n if evset.schema.is_unix_timestamp and timestamp_to_datetime:\n dst[timestamp_key].append(\n data.timestamps.astype(\"datetime64[s]\")\n )\n else:\n dst[timestamp_key].append(data.timestamps)\n\n # Features\n for feature_name, feature in zip(feature_names, data.features):\n dst[feature_name].append(feature)\n\n # Indexes\n num_timestamps = len(data.timestamps)\n for index_name, index_item in zip(index_names, index):\n dst[index_name].append(np.repeat(index_item, num_timestamps))\n\n dst = {k: np.concatenate(v) for k, v in dst.items()}\n return dst\n```\n###test function signature:\n```python\n\n def test_correct(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct", + "reference": "\n def test_correct(self):\n evset = event_set(\n timestamps=[\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n features={\n \"feature_1\": [0.5, 0.6],\n \"my_index\": [\"red\", \"blue\"],\n },\n indexes=[\"my_index\"],\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"timestamp\": np.array(\n [\"2023-11-08T17:14:38\", \"2023-11-29T21:44:46\"],\n dtype=\"datetime64[s]\",\n ),\n \"feature_1\": np.array([0.5, 0.6]),\n \"my_index\": np.array([b\"red\", b\"blue\"]),\n }\n\n for k in expected:\n np.testing.assert_array_equal(\n 
np.sort(result[k]), np.sort(expected[k])\n )", + "completions": [ + "```python\n def test_correct(self):\n evset = EventSet(\n data={\n \"timestamp\": [1610000000, 1610000000, 1610000000, 1610000000],\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n },\n schema=EventSetSchema(\n index_names=[\"index\"],\n feature_names=[\"feature1\", \"feature2\"],\n is_unix_timestamp=True,\n ),\n )\n\n result = to_numpy(evset)\n\n expected = {\n \"index\": [0, 1, 2, 3],\n \"feature1\": [1, 2, 3, 4],\n \"feature2\": [1, 2, 3, 4],\n \"timestamp\": [\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n np.datetime64(\"2021-01-01T00:00:00\"),\n ],\n }\n\n self.assertDictEqual(result, expected)\n```\n" + ], + "line": 41, + "token": 276, + "line_diff": 3, + "token_diff": 2 + }, + { + "id": 53, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def from_file(cls, path: str) -> \"LayerConfig\":\n logger.info(f\"Loading route config from {path}\")\n _, ext = os.path.splitext(path)\n with open(path, \"r\") as f:\n if ext == \".json\":\n layer = json.load(f)\n elif ext in [\".yaml\", \".yml\"]:\n layer = yaml.safe_load(f)\n else:\n raise ValueError(\n \"Unsupported file type. 
Only .json and .yaml are supported\"\n )\n\n if not is_valid(json.dumps(layer)):\n raise Exception(\"Invalid config JSON or YAML\")\n\n encoder_type = layer[\"encoder_type\"]\n encoder_name = layer[\"encoder_name\"]\n routes = []\n for route_data in layer[\"routes\"]:\n # Handle the 'llm' field specially if it exists\n if \"llm\" in route_data and route_data[\"llm\"] is not None:\n llm_data = route_data.pop(\n \"llm\"\n ) # Remove 'llm' from route_data and handle it separately\n # Use the module path directly from llm_data without modification\n llm_module_path = llm_data[\"module\"]\n # Dynamically import the module and then the class from that module\n llm_module = importlib.import_module(llm_module_path)\n llm_class = getattr(llm_module, llm_data[\"class\"])\n # Instantiate the LLM class with the provided model name\n llm = llm_class(name=llm_data[\"model\"])\n # Reassign the instantiated llm object back to route_data\n route_data[\"llm\"] = llm\n\n # Dynamically create the Route object using the remaining route_data\n route = Route(**route_data)\n routes.append(route)\n\n return cls(\n encoder_type=encoder_type, encoder_name=encoder_name, routes=routes\n )\n```\n###test function signature:\n```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_from_file_yaml", + "reference": "\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n config_path = tmp_path / \"config.yaml\"\n config_path.write_text(\n layer_yaml()\n ) # Assuming layer_yaml() returns a valid YAML string\n\n # Load the LayerConfig from the temporary file\n layer_config = LayerConfig.from_file(str(config_path))\n\n # Assertions to verify the loaded configuration\n assert layer_config.encoder_type == \"cohere\"\n assert layer_config.encoder_name == \"embed-english-v3.0\"\n 
assert len(layer_config.routes) == 2\n assert layer_config.routes[0].name == \"politics\"", + "completions": [ + "```python\n\n def test_from_file_yaml(openai_encoder, tmp_path, index_cls):\n # Create a temporary YAML file with layer configuration\n yaml_file = tmp_path / \"test_layer.yaml\"\n yaml_file.write_text(\n \"\"\"\n encoder_type: openai\n encoder_name: gpt2\n routes:\n - name: test_route\n model_name: test_model\n llm:\n module: tests.test_layer\n class: TestLLM\n model: test_model\n encoder_kwargs:\n max_position_embeddings: 512\n num_attention_heads: 8\n intermediate_size: 2048\n hidden_act: gelu\n hidden_dropout_prob: 0.1\n attention_probs_dropout_prob: 0.1\n type_vocab_size: 16\n initializer_range: 0.02\n layer_norm_eps: 1e-12\n use_cache: True\n is_decoder: False\n classifier_dropout: 0.1\n use_multiplicative_input_mask: False\n num_labels: 2\n output_attentions: False\n output_hidden_states: False\n return_dict: True\n task_specific_params:\n summarization:\n max_length: 20\n min_length: 5\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n top_p: 0.9\n repetition_penalty: 
1.5\n length_penalty: 1.0\n do_sample: True\n num_beams: 5\n early_stopping: True\n no_repeat_ngram_size: 3\n num_return_sequences: 3\n temperature: 0.7\n top_k: 10\n" + ], + "line": 49, + "token": 405, + "line_diff": 3, + "token_diff": 4 + }, + { + "id": 54, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def bind_all(self, circuit: QuantumCircuit) -> np.ndarray:\n arr = np.empty(self.shape, dtype=object)\n for idx in np.ndindex(self.shape):\n arr[idx] = self.bind(circuit, idx)\n return arr\n```\n###test function signature:\n```python\n def test_data(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_data", + "reference": " def test_data(self):\n with self.subTest(\"binding a single value\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 0)\n self.assertEqual(ba.shape, ())\n self.assertEqual(ba.size, 1)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuit = ba.bind(self.circuit, ())\n self.assertEqual(bound_circuit, self.circuit.assign_parameters(vals))\n\n with self.subTest(\"binding an array\"):\n vals = np.linspace(0, 1, 300).reshape((2, 3, 50))\n data = {self.params: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 50)\n self.assertEqual(ba.ndim, 2)\n self.assertEqual(ba.shape, (2, 3))\n self.assertEqual(ba.size, 6)\n self.assertEqual(ba.data, {tuple(param.name for param in self.params): vals})\n\n bound_circuits = ba.bind_all(self.circuit)\n self.assertEqual(bound_circuits.shape, (2, 3))\n self.assertEqual(bound_circuits[0, 0], self.circuit.assign_parameters(vals[0, 0]))\n self.assertEqual(bound_circuits[0, 1], 
self.circuit.assign_parameters(vals[0, 1]))\n self.assertEqual(bound_circuits[0, 2], self.circuit.assign_parameters(vals[0, 2]))\n self.assertEqual(bound_circuits[1, 0], self.circuit.assign_parameters(vals[1, 0]))\n self.assertEqual(bound_circuits[1, 1], self.circuit.assign_parameters(vals[1, 1]))\n self.assertEqual(bound_circuits[1, 2], self.circuit.assign_parameters(vals[1, 2]))\n\n with self.subTest(\"binding a single param\"):\n vals = np.linspace(0, 1, 50)\n data = {self.params[0]: vals}\n ba = BindingsArray(data=data)\n self.assertEqual(ba.num_parameters, 1)\n self.assertEqual(ba.ndim, 1)\n self.assertEqual(ba.shape, (50,))\n self.assertEqual(ba.size, 50)\n self.assertEqual(list(ba.data.keys()), [(self.params[0].name,)])\n np.testing.assert_allclose(list(ba.data.values()), [vals[..., np.newaxis]])", + "completions": [ + "```python\n def test_data(self):\n from qiskit import QuantumCircuit\n from qiskit.quantum_info import Pauli\n from qiskit.quantum_info import PauliTable\n from qiskit.quantum_info import PauliSumOp\n from qiskit.quantum_info import PauliOp\n from qiskit.quantum_info import PauliExpectation\n from qiskit.quantum_info import PauliMoment\n from qiskit.quantum_info import PauliTensor\n from qiskit.quantum_info import PauliList\n from qiskit.quantum_info import PauliTransform\n from qiskit.quantum_info import PauliTransformOp\n from qiskit.quantum_info import PauliTransformList\n from qiskit.quantum_info import PauliTransformSum\n from qiskit.quantum_info import PauliTransformProduct\n from qiskit.quantum_info import PauliTransformExpectation\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from 
qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum_info import PauliTransformVariance\n from qiskit.quantum" + ], + "line": 15, + "token": 122, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 55, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef set_seed(\n seed: Optional[int] = None, debug_mode: Optional[Union[str, int]] = None\n) -> int:\n world_size, rank = get_world_size_and_rank()\n max_val = np.iinfo(np.uint32).max - world_size + 1\n min_val = np.iinfo(np.uint32).min\n if seed is None:\n rand_seed = torch.randint(min_val, max_val, (1,))\n seed = _broadcast_tensor(rand_seed, 0).item() # sync seed across ranks\n if seed < min_val or seed > max_val:\n raise ValueError(\n f\"Invalid seed value provided: {seed}. Value must be in the range [{min_val}, {max_val}]\"\n )\n local_seed = seed + rank\n _log.debug(\n f\"Setting manual seed to local seed {local_seed}. 
Local seed is seed + rank = {seed} + {rank}\"\n )\n\n torch.manual_seed(local_seed)\n np.random.seed(local_seed)\n random.seed(local_seed)\n\n if debug_mode is not None:\n _log.debug(f\"Setting deterministic debug mode to {debug_mode}\")\n torch.set_deterministic_debug_mode(debug_mode)\n deterministic_debug_mode = torch.get_deterministic_debug_mode()\n if deterministic_debug_mode == 0:\n _log.debug(\"Disabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = False\n torch.backends.cudnn.benchmark = True\n else:\n _log.debug(\"Enabling cuDNN deterministic mode\")\n torch.backends.cudnn.deterministic = True\n torch.backends.cudnn.benchmark = False\n # reference: https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility\n os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n\n return seed\n```\n###test function signature:\n```python\n\n def test_deterministic_true(self) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_deterministic_true", + "reference": "\n def test_deterministic_true(self) -> None:\n for det_debug_mode, det_debug_mode_str in [(1, \"warn\"), (2, \"error\")]:\n warn_only = det_debug_mode == 1\n for debug_mode in (det_debug_mode, det_debug_mode_str):\n set_seed(42, debug_mode=debug_mode)\n assert torch.backends.cudnn.deterministic\n assert not torch.backends.cudnn.benchmark\n assert det_debug_mode == torch.get_deterministic_debug_mode()\n assert torch.are_deterministic_algorithms_enabled()\n assert (\n warn_only == torch.is_deterministic_algorithms_warn_only_enabled()\n )\n assert os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] == \":4096:8\"", + "completions": [ + "```python\n def test_deterministic_true(self) -> None:\n seed = set_seed(debug_mode=True)\n self.assertEqual(seed, 123456789)\n```\n" + ], + "line": 45, + "token": 345, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 56, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n 
else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n```\n###test function signature:\n```python\n\ndef test_concat_periodarray_2d():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_concat_periodarray_2d", + "reference": "\ndef test_concat_periodarray_2d():\n pi = pd.period_range(\"2016-01-01\", periods=36, freq=\"D\")\n arr = pi._data.reshape(6, 6)\n\n result = _concat.concat_compat([arr[:2], arr[2:]], axis=0)\n tm.assert_period_array_equal(result, arr)\n\n result = _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=1)\n tm.assert_period_array_equal(result, arr)\n\n msg = (\n \"all the input array dimensions.* for the concatenation axis must match exactly\"\n )\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:, :2], arr[:, 2:]], axis=0)\n\n with pytest.raises(ValueError, match=msg):\n _concat.concat_compat([arr[:2], arr[2:]], axis=1)", + "completions": [ + "```python\n\ndef test_concat_periodarray_2d():\n import numpy as np\n import pandas as pd\n from pandas.core.dtypes.common import is_period_dtype\n\n def concat_compat(to_concat, axis: AxisInt = 0, ea_compat_axis: bool = False):\n\n # filter empty arrays\n # 1-d dtypes always are included here\n def is_nonempty(x) -> bool:\n if x.ndim <= axis:\n return True\n return x.shape[axis] > 0\n\n # If all arrays are 
empty, there's nothing to convert, just short-cut to\n # the concatenation, #3121.\n #\n # Creating an empty array directly is tempting, but the winnings would be\n # marginal given that it would still require shape & dtype calculation and\n # np.concatenate which has them both implemented is compiled.\n non_empties = [x for x in to_concat if is_nonempty(x)]\n if non_empties and axis == 0 and not ea_compat_axis:\n # ea_compat_axis see GH#39574\n to_concat = non_empties\n\n dtypes = {obj.dtype for obj in to_concat}\n kinds = {obj.dtype.kind for obj in to_concat}\n contains_datetime = any(\n isinstance(dtype, (np.dtype, DatetimeTZDtype)) and dtype.kind in [\"m\", \"M\"]\n for dtype in dtypes\n ) or any(isinstance(obj, ABCExtensionArray) and obj.ndim > 1 for obj in to_concat)\n\n all_empty = not len(non_empties)\n single_dtype = len({x.dtype for x in to_concat}) == 1\n any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)\n\n if contains_datetime:\n return _concat_datetime(to_concat, axis=axis)\n\n if any_ea:\n # we ignore axis here, as internally concatting with EAs is always\n # for axis=0\n if not single_dtype:\n target_dtype = find_common_type([x.dtype for x in to_concat])\n target_dtype = common_dtype_categorical_compat(to_concat, target_dtype)\n to_concat = [\n astype_array(arr, target_dtype, copy=False) for arr in to_concat\n ]\n\n if isinstance(to_concat[0], ABCExtensionArray):\n # TODO: what about EA-backed Index?\n cls = type(to_concat[0])\n return cls._concat_same_type(to_concat)\n else:\n return np.concatenate(to_concat)\n\n elif all_empty:\n # we have all empties, but may need to coerce the result dtype to\n # object if we have non-numeric type operands (numpy would otherwise\n # cast this to float)\n if len(kinds) != 1:\n if not len(kinds - {\"i\", \"u\", \"f\"}) or not len(kinds - {\"b\", \"i\", \"u\"}):\n # let numpy coerce\n pass\n else:\n # coerce to object\n to_concat = [x.astype(\"object\") for x in to_concat]\n kinds = {\"o\"}\n\n 
result = np.concatenate(to_concat, axis=axis)\n if \"b\" in kinds and result.dtype.kind in [\"i\", \"u\", \"f\"]:\n # GH#39817 cast to object instead of casting bools to numeric\n result = result.astype(object, copy=False)\n return result\n\n def _concat_datetime(to_concat, axis: AxisInt = 0):\n from pandas.core.arrays import PeriodArray\n\n if axis == 0:\n " + ], + "line": 70, + "token": 610, + "line_diff": 5, + "token_diff": 6 + }, + { + "id": 57, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef ingest_ocp_payload(request):\n request_id = uuid4().hex\n response_data = {\"request-id\": request_id, \"payload-name\": []}\n if request.method == \"POST\":\n if not request.FILES:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for _, file in request.FILES.items():\n payload_name = file.name\n response_data[\"payload-name\"].append(payload_name)\n s3_signature = get_s3_signature(settings.S3_ENDPOINT, payload_name, method=\"put_object\")\n res = upload_file_to_s3(s3_signature, data=file.file)\n if res.status_code == HTTPStatus.OK:\n response_data[\"upload\"] = \"success\"\n else:\n response_data[\"upload\"] = \"failed\"\n response_data[\"failed-reason\"] = res.reason\n return Response(response_data, status=res.status_code)\n send_payload(request_id, payload_name)\n else:\n params = request.query_params\n payload_names = params.get(\"payload_name\")\n if not payload_names:\n response_data[\"error\"] = \"no payload sent\"\n return Response(response_data, status=HTTPStatus.BAD_REQUEST)\n for payload_name in payload_names.split(\",\"):\n send_payload(request_id, payload_name)\n response_data[\"payload-name\"].append(payload_name)\n\n response_data[\"ingest-started\"] = True\n\n return 
Response(response_data, status=HTTPStatus.ACCEPTED)\n```\n###test function signature:\n```python\n\n def test_ingest_ocp_payload(self):\n # Arrange\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ingest_ocp_payload", + "reference": "\n def test_ingest_ocp_payload(self):\n # Arrange\n file1 = SimpleUploadedFile(\"file1.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n file2 = SimpleUploadedFile(\"file2.gz\", b\"file_content\", content_type=\"multipart/form-data\")\n test_table = [\n # Happy path tests\n (\n \"POST\",\n {\"file1\": file1, \"file2\": file2},\n \"\",\n HTTPStatus.ACCEPTED,\n {\"upload\": \"success\", \"ingest-started\": True},\n \"happy_path_post\",\n ),\n (\"GET\", {}, \"?payload_name=file1,file2\", HTTPStatus.ACCEPTED, {\"ingest-started\": True}, \"happy_path_get\"),\n # Edge cases\n (\"POST\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_files_post\"),\n (\"GET\", {}, \"\", HTTPStatus.BAD_REQUEST, {\"error\": \"no payload sent\"}, \"edge_case_no_payload_name_get\"),\n # Error cases\n (\"POST\", {\"file1\": file1}, \"\", HTTPStatus.INTERNAL_SERVER_ERROR, {\"upload\": \"failed\"}, \"error_case_post\"),\n (\"GET\", {}, \"?payload_name=non_existent_file\", HTTPStatus.ACCEPTED, {}, \"error_case_get\"),\n ]\n for test in test_table:\n method, files, query_params, expected_status, expected_response, test_id = test\n with self.subTest(test=test_id):\n\n mock_response = MagicMock()\n mock_response.status_code = (\n HTTPStatus.OK if expected_response.get(\"upload\") != \"failed\" else HTTPStatus.INTERNAL_SERVER_ERROR\n )\n mock_response.reason = \"Some reason\"\n\n with patch(\"masu.api.ingest_ocp_payload.get_s3_signature\") as mock_get_s3_signature, patch(\n \"masu.api.ingest_ocp_payload.upload_file_to_s3\", return_value=mock_response\n ) as mock_upload_file_to_s3, patch(\"masu.api.ingest_ocp_payload.send_payload\") as mock_send_payload:\n\n request = (\n 
APIRequestFactory().get(f\"/{query_params}\")\n if method == \"GET\"\n else APIRequestFactory().post(\"/\", data=files)\n )\n\n # Act\n response = ingest_ocp_payload(request)\n\n # Assert\n self.assertEqual(response.status_code, expected_status)\n if method == \"POST\" and files:\n mock_get_s3_signature.assert_called()\n mock_upload_file_to_s3.assert_called()\n if method == \"GET\" and \"payload_name\" in query_params:\n mock_send_payload.assert_called()\n for key, value in expected_response.items():\n self.assertEqual(response.data[key], value)", + "completions": [ + "```python\n def test_ingest_ocp_payload(self):\n # Arrange\n request = self.factory.post(\n \"/api/v1/ingest/ocp/payload\",\n data={\"payload_name\": \"payload_name\"},\n content_type=\"multipart/form-data\",\n )\n request.FILES = {\"payload_name\": \"payload_name\"}\n # Act\n response = ingest_ocp_payload(request)\n # Assert\n self.assertEqual(response.status_code, HTTPStatus.ACCEPTED)\n self.assertEqual(response.data[\"request-id\"], \"request_id\")\n self.assertEqual(response.data[\"payload-name\"], [\"payload_name\"])\n self.assertEqual(response.data[\"ingest-started\"], True)\n```\n" + ], + "line": 41, + "token": 308, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 58, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\nasync def compute_nb_items_with_metadata_field(\n project_id: str,\n metadata_field: str,\n collection_name: str,\n quantile_value: float = 0.1,\n) -> Tuple[int, float, int]:\n\n mongo_db = await get_mongo_db()\n\n try:\n pipeline = [\n # Match on the project id and the existence of the metadata field\n {\n \"$match\": {\n \"project_id\": project_id,\n f\"metadata.{metadata_field}\": {\"$exists\": True, \"$ne\": None},\n }\n },\n {\n \"$group\": {\n \"_id\": \"$user_id\", # Group by 'user_id'\n \"count\": {\"$sum\": 1}, # Count the documents in each group\n }\n },\n {\n \"$sort\": {\"count\": -1} # Sort by 'count' in descending order\n },\n ]\n\n # Execute the aggregation pipeline\n results = [] # List to hold the results\n async for group in mongo_db[collection_name].aggregate(pipeline):\n results.append(group[\"count\"])\n\n logger.debug(\n f\"Results for {metadata_field} in collection {collection_name} for project {project_id}: {results}\"\n )\n\n # If no results, return 0\n if len(results) == 0:\n return 0, 0.0, 0\n\n # Get the average and the quantiles\n average = sum(results) / len(results)\n bottom_quantile = results[int(len(results) * quantile_value)]\n top_quantile = results[int(len(results) * (1 - quantile_value))]\n\n return bottom_quantile, average, top_quantile\n\n except Exception as e:\n logger.warning(\n f\"Failed to fetch the number of {metadata_field} in collection {collection_name} for project {project_id}: {e}\",\n )\n return 0, 0.0, 0\n```\n###test function signature:\n```python\n\nasync def test_main_pipeline(db, populated_project):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_main_pipeline", + "reference": "\nasync def test_main_pipeline(db, populated_project):\n async for mongo_db in db:\n 
test_project_id = populated_project.id\n\n # Run the function\n topics = await fetch_topics(test_project_id)\n\n logger.debug(f\"Topics: {topics}\")\n\n count = await nb_items_with_a_metadata_field(\n test_project_id, \"tasks\", \"user_id\"\n )\n\n assert count > 0\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_successrate_metadata_quantiles(test_project_id, \"user_id\")\n\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert bottom_quantile >= 0\n # assert bottom_quantile <= top_quantile\n\n test_project_id = populated_project.id\n test_metadata_field = \"user_id\"\n test_collection_name = \"tasks\"\n test_quantile_value = 0.1\n\n # Run the function\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_nb_items_with_metadata_field(\n test_project_id,\n test_metadata_field,\n test_collection_name,\n test_quantile_value,\n )\n\n logger.debug(\"runned compute_nb_items_with_metadata_field\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, int)\n assert isinstance(average, float)\n assert isinstance(top_quantile, int)\n # assert bottom_quantile <= average <= top_quantile\n\n (\n bottom_quantile,\n average,\n top_quantile,\n ) = await compute_session_length_per_metadata(\n test_project_id,\n test_metadata_field,\n test_quantile_value,\n )\n logger.debug(\"runned compute_session_length_per_metadata\")\n logger.debug(f\"Bottom quantile: {bottom_quantile}\")\n logger.debug(f\"Average: {average}\")\n logger.debug(f\"Top quantile: {top_quantile}\")\n\n assert isinstance(bottom_quantile, float)\n assert isinstance(average, float)\n assert isinstance(top_quantile, float)\n # assert bottom_quantile <= average <= top_quantile", + "completions": [ + "```python\nasync def test_main_pipeline(db, populated_project):\n # Arrange\n project_id = 
populated_project.id\n collection_name = \"collection_name\"\n metadata_field = \"metadata_field\"\n quantile_value = 0.1\n\n # Act\n result = await compute_nb_items_with_metadata_field(\n project_id, metadata_field, collection_name, quantile_value\n )\n\n # Assert\n assert result == (0, 0.0, 0)\n```\n" + ], + "line": 56, + "token": 430, + "line_diff": 4, + "token_diff": 5 + }, + { + "id": 59, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n```\n###test function signature:\n```python\n\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_engine_auto_error_message", + "reference": "\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. are installed but bad version\n from pandas.compat._optional import VERSIONS\n\n # Do we have engines installed, but a bad version of them?\n pa_min_ver = VERSIONS.get(\"pyarrow\")\n fp_min_ver = VERSIONS.get(\"fastparquet\")\n have_pa_bad_version = (\n False\n if not _HAVE_PYARROW\n else Version(pyarrow.__version__) < Version(pa_min_ver)\n )\n have_fp_bad_version = (\n False\n if not _HAVE_FASTPARQUET\n else Version(fastparquet.__version__) < Version(fp_min_ver)\n )\n # Do we have usable engines installed?\n have_usable_pa = _HAVE_PYARROW and not have_pa_bad_version\n have_usable_fp = _HAVE_FASTPARQUET and not have_fp_bad_version\n\n if not have_usable_pa and not have_usable_fp:\n # No usable engines found.\n if have_pa_bad_version:\n match = f\"Pandas requires version .{pa_min_ver}. or newer of .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .pyarrow.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n\n if have_fp_bad_version:\n match = f\"Pandas requires version .{fp_min_ver}. or newer of .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")\n else:\n match = \"Missing optional dependency .fastparquet.\"\n with pytest.raises(ImportError, match=match):\n get_engine(\"auto\")", + "completions": [ + "```python\ndef test_get_engine_auto_error_message():\n # Expect different error messages from get_engine(engine=\"auto\")\n # if engines aren't installed vs. 
are installed but bad version\n import pyarrow\n import fastparquet\n from pyarrow.parquet import ParquetEngine\n from fastparquet.parquet import ParquetEngine as FastParquetEngine\n\n def get_engine(engine: str) -> BaseImpl:\n if engine == \"auto\":\n engine = get_option(\"io.parquet.engine\")\n\n if engine == \"auto\":\n # try engines in this order\n engine_classes = [PyArrowImpl, FastParquetImpl]\n\n error_msgs = \"\"\n for engine_class in engine_classes:\n try:\n return engine_class()\n except ImportError as err:\n error_msgs += \"\\n - \" + str(err)\n\n raise ImportError(\n \"Unable to find a usable engine; \"\n \"tried using: 'pyarrow', 'fastparquet'.\\n\"\n \"A suitable version of \"\n \"pyarrow or fastparquet is required for parquet \"\n \"support.\\n\"\n \"Trying to import the above resulted in these errors:\"\n f\"{error_msgs}\"\n )\n\n if engine == \"pyarrow\":\n return PyArrowImpl()\n elif engine == \"fastparquet\":\n return FastParquetImpl()\n\n raise ValueError(\"engine must be one of 'pyarrow', 'fastparquet'\")\n\n # pyarrow is installed\n assert get_engine(\"auto\").__class__ == PyArrowImpl\n assert get_engine(\"pyarrow\").__class__ == PyArrowImpl\n assert get_engine(\"fastparquet\").__class__ == FastParquetImpl\n\n # pyarrow is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'pyarrow'\" in str(excinfo.value)\n\n # fastparquet is not installed\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in 
str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"No module named 'fastparquet'\" in str(excinfo.value)\n\n # pyarrow is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of pyarrow is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n assert \"pyarrow.parquet.ParquetEngine\" in str(excinfo.value)\n\n # fastparquet is installed but bad version\n with pytest.raises(ImportError) as excinfo:\n get_engine(\"auto\")\n assert \"Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'\" in str(excinfo.value)\n assert \"A suitable version of fastparquet is required for parquet support.\" in str(excinfo.value)\n assert \"Trying to import the above resulted in these errors:\" in str(excinfo.value)\n" + ], + "line": 38, + "token": 277, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 60, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def attribute(\n self,\n inp: InterpretableInput,\n target: Union[str, torch.Tensor, None] = None,\n num_trials: int = 1,\n gen_args: Optional[Dict] = None,\n # internal callback hook can be used for logging\n _inspect_forward: Optional[Callable] = None,\n **kwargs,\n ) -> LLMAttributionResult:\n\n assert isinstance(\n inp, self.SUPPORTED_INPUTS\n ), f\"LLMAttribution does not support input type {type(inp)}\"\n\n if target is None:\n # generate when None\n assert hasattr(self.model, \"generate\") and callable(self.model.generate), (\n \"The model does not have recognizable generate function.\"\n \"Target must be given for attribution\"\n )\n\n if not gen_args:\n gen_args = DEFAULT_GEN_ARGS\n\n model_inp = self._format_model_input(inp.to_model_input())\n output_tokens = self.model.generate(model_inp, **gen_args)\n target_tokens = output_tokens[0][model_inp.size(1) :]\n else:\n assert gen_args is None, \"gen_args must be None when target is given\"\n\n if type(target) is str:\n # exclude sos\n target_tokens = self.tokenizer.encode(target)[1:]\n target_tokens = torch.tensor(target_tokens)\n elif type(target) is torch.Tensor:\n target_tokens = target\n\n attr = torch.zeros(\n [\n 1 + len(target_tokens) if self.include_per_token_attr else 1,\n inp.n_itp_features,\n ],\n dtype=torch.float,\n device=self.device,\n )\n\n for _ in range(num_trials):\n attr_input = inp.to_tensor().to(self.device)\n\n cur_attr = self.attr_method.attribute(\n attr_input,\n additional_forward_args=(inp, target_tokens, _inspect_forward),\n **kwargs,\n )\n\n # temp necessary due to FA & Shapley's different return shape of multi-task\n # FA will flatten output shape internally (n_output_token, n_itp_features)\n # Shapley will keep output shape (batch, n_output_token, n_input_features)\n 
cur_attr = cur_attr.reshape(attr.shape)\n\n attr += cur_attr\n\n attr = attr / num_trials\n\n attr = inp.format_attr(attr)\n\n return LLMAttributionResult(\n attr[0],\n attr[1:]\n if self.include_per_token_attr\n else None, # shape(n_output_token, n_input_features)\n inp.values,\n self.tokenizer.convert_ids_to_tokens(target_tokens),\n )\n```\n###test function signature:\n```python\n\n def test_llm_attr(self, AttrClass) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_llm_attr", + "reference": "\n def test_llm_attr(self, AttrClass) -> None:\n llm = DummyLLM()\n llm.to(self.device)\n tokenizer = DummyTokenizer()\n llm_attr = LLMAttribution(AttrClass(llm), tokenizer)\n\n inp = TextTemplateInput(\"{} b {} {} e {}\", [\"a\", \"c\", \"d\", \"f\"])\n res = llm_attr.attribute(inp, \"m n o p q\")\n\n self.assertEqual(res.seq_attr.shape, (4,))\n self.assertEqual(cast(Tensor, res.token_attr).shape, (5, 4))\n self.assertEqual(res.input_tokens, [\"a\", \"c\", \"d\", \"f\"])\n self.assertEqual(res.output_tokens, [\"m\", \"n\", \"o\", \"p\", \"q\"])\n self.assertEqual(res.seq_attr.device.type, self.device)\n self.assertEqual(cast(Tensor, res.token_attr).device.type, self.device)", + "completions": [ + "```python\n def test_llm_attr(self, AttrClass) -> None:\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM 
attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # 
test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for LLM attribute\n # test case for L" + ], + "line": 72, + "token": 464, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 61, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def plan(self, task: str):\n agent_pool_description = self.agent_pool.get_agent_pool_description()\n feedback = self.get_last_feedback()\n last_plan = self.get_last_plan()\n\n # run the planning\n plan = self.planning.execute(task, agent_pool_description, last_plan, feedback)\n\n logger.info(f\"Plan: {plan}\")\n\n self.shared_memory.add(\n event_type=EventType.planning, agent=self.name, content=str(plan)\n )\n\n return plan\n```\n###test function signature:\n```python\n\ndef test_planner_succeeds(get_llm): # noqa: F811\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_planner_succeeds", + "reference": "\ndef test_planner_succeeds(get_llm): # noqa: F811\n llm = get_llm(__file__, test_planner_succeeds.__name__)\n\n physicist_description = (\n \"The physicist agent answers questions or research about physics-related 
topics\"\n )\n physicist = Physicist(\n name=\"Physicist\",\n description=physicist_description,\n llm=llm,\n )\n\n agent_pool = AgentPool()\n agent_pool.add_agents([physicist])\n\n shared_memory = SharedMemory(\n objective=\"Share the information across different agents.\",\n agent_pool=agent_pool,\n )\n\n planner = Planner(\n name=\"planner\",\n agent_pool=agent_pool,\n shared_memory=shared_memory,\n llm=llm,\n )\n\n task = \"\"\"We need to render a highly complex 3D image on the solar system. We can use any publicly avaliable\n resources to achieve this task.\"\"\" # noqa: E501\n\n p = planner.plan(task=task)\n assert len(p.steps) > 0", + "completions": [ + "```python\ndef test_planner_succeeds(get_llm): # noqa: F811\n planner = get_llm\n planner.agent_pool.set_agent_pool_description(\n {\n \"agent_pool\": {\n \"agents\": [\n {\n \"name\": \"agent1\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n {\n \"name\": \"agent2\",\n \"capabilities\": [\"cap1\", \"cap2\"],\n \"max_task_queue_size\": 10,\n },\n ]\n }\n }\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task1\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task2\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task3\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": 
\"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task4\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task5\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent2\",\n content='{\"task\": \"task6\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}]}}',\n )\n planner.shared_memory.add(\n event_type=EventType.planning,\n agent=\"agent1\",\n content='{\"task\": \"task7\", \"agent_pool\": {\"agents\": [{\"name\": \"agent1\", \"capabilities\": [\"cap1\", \"cap2\"], \"max_task_queue_size\": 10}, {\"name\": \"agent2\", \"capabilities\": [\"cap1\"," + ], + "line": 21, + "token": 149, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 62, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def add_edge(self, from_step: str, to_step: str) -> None:\n if from_step not in self.G:\n raise ValueError(f\"Step with name '{from_step}' does not exist\")\n\n if to_step not in self.G:\n raise ValueError(f\"Step with name '{to_step}' does not exist\")\n\n if to_step in self.G[from_step]:\n raise ValueError(\n f\"There is already a edge from '{to_step}' to '{from_step}'\"\n )\n\n if to_step in nx.ancestors(self.G, from_step):\n raise ValueError(\n f\"Cannot add edge from '{from_step}' to '{to_step}' as it would create a cycle.\"\n )\n\n self.G.add_edge(from_step, to_step)\n```\n###test function signature:\n```python\n\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validate_missing_step_input", + "reference": "\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n class DummyStep3(Step):\n @property\n def inputs(self) -> List[str]:\n return [\"instruction\"]\n\n @property\n def outputs(self) -> List[str]:\n return [\"response\"]\n\n def process(self) -> \"StepOutput\": # type: ignore\n yield [{\"response\": \"response1\"}]\n\n dag = DAG()\n dag.add_step(dummy_generator_step)\n dag.add_step(DummyStep3(name=\"dummy_step_3\", pipeline=pipeline))\n dag.add_edge(\"dummy_generator_step\", \"dummy_step_3\")\n\n with pytest.raises(\n ValueError,\n match=\"Step 'dummy_step_3' should have a parameter with type hint `StepInput`\",\n ):\n dag.validate()", + "completions": [ + "```python\n def test_validate_missing_step_input(\n self, dummy_generator_step: \"GeneratorStep\", pipeline: \"Pipeline\"\n ) -> None:\n with pytest.raises(ValueError):\n 
dummy_generator_step.add_edge(\"step_1\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_2\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_1\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_1\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_2\", \"step_3\")\n\n with pytest.raises(ValueError):\n dummy_generator_step.add_edge(\"step_3\", \"step_2\")\n```\n" + ], + "line": 26, + "token": 238, + "line_diff": 1, + "token_diff": 2 + }, + { + "id": 63, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. 
You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. 
Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # daily/billing route\n # heating/cooling degree day aggregations. 
Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_hourly_temp_mean", + "reference": "\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n meter_data = il_electricity_cdd_hdd_hourly[\"meter_data\"][\"2016-03-01\":\"2016-07-01\"]\n temperature_data = il_electricity_cdd_hdd_hourly[\"temperature_data\"][\n \"2016-03-01\":\"2016-07-01\"\n ]\n df = compute_temperature_features(meter_data.index, temperature_data)\n assert list(sorted(df.columns)) == [\n \"n_hours_dropped\",\n \"n_hours_kept\",\n \"temperature_mean\",\n ]\n assert df.shape == (2952, 3)\n\n assert round(df.temperature_mean.mean()) == 62.0", + "completions": [ + "```python\n\ndef test_compute_temperature_features_hourly_temp_mean(il_electricity_cdd_hdd_hourly):\n # pick a slice with both hdd and cdd\n df = il_electricity_cdd_hdd_hourly.copy()\n df = df.iloc[100:150]\n df = df.assign(\n **{\n \"cdd_10\": 
np.maximum(df.temperature_mean - 10, 0),\n \"cdd_20\": np.maximum(df.temperature_mean - 20, 0),\n \"hdd_10\": np.maximum(10 - df.temperature_mean, 0),\n \"hdd_20\": np.maximum(20 - df.temperature_mean, 0),\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n df = df.drop([\"temperature_mean\"], axis=1)\n df = df.reindex(il_electricity_cdd_hdd_hourly.index)\n df = overwrite_partial_rows_with_nan(df)\n df = df.iloc[:-1].reindex(df.index)\n assert_frame_equal(df, il_electricity_cdd_hdd_hourly)\n```\n" + ], + "line": 167, + "token": 992, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 64, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef pandasSQL_builder(\n con,\n schema: str | None = None,\n need_transaction: bool = False,\n) -> PandasSQL:\n import sqlite3\n\n if isinstance(con, sqlite3.Connection) or con is None:\n return SQLiteDatabase(con)\n\n sqlalchemy = import_optional_dependency(\"sqlalchemy\", errors=\"ignore\")\n\n if isinstance(con, str) and sqlalchemy is None:\n raise ImportError(\"Using URI string without sqlalchemy installed.\")\n\n if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Connectable)):\n return SQLDatabase(con, schema, need_transaction)\n\n warnings.warn(\n \"pandas only supports SQLAlchemy connectable (engine/connection) or \"\n \"database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 \"\n \"objects are not tested. 
Please consider using SQLAlchemy.\",\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return SQLiteDatabase(con)\n```\n###test function signature:\n```python\n\n def test_execute_closed_connection(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_execute_closed_connection", + "reference": "\n def test_execute_closed_connection(self):\n create_sql = \"\"\"\n CREATE TABLE test\n (\n a TEXT,\n b TEXT,\n c REAL,\n PRIMARY KEY (a, b)\n );\n \"\"\"\n with contextlib.closing(sqlite3.connect(\":memory:\")) as conn:\n cur = conn.cursor()\n cur.execute(create_sql)\n\n with sql.pandasSQL_builder(conn) as pandas_sql:\n pandas_sql.execute('INSERT INTO test VALUES(\"foo\", \"bar\", 1.234)')\n\n msg = \"Cannot operate on a closed database.\"\n with pytest.raises(sqlite3.ProgrammingError, match=msg):\n tquery(\"select * from test\", con=conn)", + "completions": [ + "```python\n def test_execute_closed_connection(self):\n from pandas.io.sql import PandasSQL\n from pandas.io.sql.sqldatabase import SQLiteDatabase\n from pandas.io.sql.pandas_sql import PandasSQLQueryCompiler\n from pandas.io.sql.pandas_sql import PandasSQLEngine\n from pandas.io.sql.pandas_sql import PandasSQLTable\n from pandas.io.sql.pandas_sql import PandasSQLConverter\n from pandas.io.sql.pandas_sql import PandasSQLAlchemy\n from pandas.io.sql.pandas_sql import PandasSQLQuery\n from pandas.io.sql.pandas_sql import PandasSQLExecutionContext\n from pandas.io.sql.pandas_sql import PandasSQLCompiler\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import 
PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas.io.sql.pandas_sql import PandasSQLExecutionResult\n from pandas." + ], + "line": 31, + "token": 237, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 65, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef compute_temperature_features(\n meter_data_index,\n temperature_data,\n heating_balance_points=None,\n cooling_balance_points=None,\n data_quality=False,\n temperature_mean=True,\n degree_day_method=\"daily\",\n percent_hourly_coverage_per_day=0.5,\n percent_hourly_coverage_per_billing_period=0.9,\n use_mean_daily_values=True,\n tolerance=None,\n keep_partial_nan_rows=False,\n):\n if temperature_data.index.freq != \"H\":\n raise ValueError(\n \"temperature_data.index must have hourly frequency (freq='H').\"\n \" Found: {}\".format(temperature_data.index.freq)\n )\n\n if not temperature_data.index.tz:\n raise ValueError(\n \"temperature_data.index must be timezone-aware. You can set it with\"\n \" temperature_data.tz_localize(...).\"\n )\n\n if meter_data_index.freq is None and meter_data_index.inferred_freq == \"H\":\n raise ValueError(\n \"If you have hourly data explicitly set the frequency\"\n \" of the dataframe by setting\"\n \"``meter_data_index.freq =\"\n \" pd.tseries.frequencies.to_offset('H').\"\n )\n\n if not meter_data_index.tz:\n raise ValueError(\n \"meter_data_index must be timezone-aware. 
You can set it with\"\n \" meter_data.tz_localize(...).\"\n )\n\n if meter_data_index.duplicated().any():\n raise ValueError(\"Duplicates found in input meter trace index.\")\n\n temp_agg_funcs = []\n temp_agg_column_renames = {}\n\n if heating_balance_points is None:\n heating_balance_points = []\n if cooling_balance_points is None:\n cooling_balance_points = []\n\n if meter_data_index.freq is not None:\n try:\n freq_timedelta = pd.Timedelta(meter_data_index.freq)\n except ValueError: # freq cannot be converted to timedelta\n freq_timedelta = None\n else:\n freq_timedelta = None\n\n if tolerance is None:\n tolerance = freq_timedelta\n\n if not (heating_balance_points == [] and cooling_balance_points == []):\n if degree_day_method == \"hourly\":\n pass\n elif degree_day_method == \"daily\":\n if meter_data_index.freq == \"H\":\n raise ValueError(\n \"degree_day_method='daily' must be used with\"\n \" daily meter data. Found: 'hourly'\".format(degree_day_method)\n )\n else:\n raise ValueError(\"method not supported: {}\".format(degree_day_method))\n\n if freq_timedelta == pd.Timedelta(\"1H\"):\n # special fast route for hourly data.\n df = temperature_data.to_frame(\"temperature_mean\").reindex(meter_data_index)\n\n if use_mean_daily_values:\n n_days = 1\n else:\n n_days = 1.0 / 24.0\n\n df = df.assign(\n **{\n \"cdd_{}\".format(bp): np.maximum(df.temperature_mean - bp, 0) * n_days\n for bp in cooling_balance_points\n }\n )\n df = df.assign(\n **{\n \"hdd_{}\".format(bp): np.maximum(bp - df.temperature_mean, 0) * n_days\n for bp in heating_balance_points\n }\n )\n df = df.assign(\n n_hours_dropped=df.temperature_mean.isnull().astype(int),\n n_hours_kept=df.temperature_mean.notnull().astype(int),\n )\n # TODO(philngo): bad interface or maybe this is just wrong for some reason?\n if data_quality:\n df = df.assign(\n temperature_null=df.n_hours_dropped,\n temperature_not_null=df.n_hours_kept,\n )\n if not temperature_mean:\n del df[\"temperature_mean\"]\n else:\n # 
daily/billing route\n # heating/cooling degree day aggregations. Needed for n_days fields as well.\n temp_agg_funcs.extend(\n _degree_day_columns(\n heating_balance_points=heating_balance_points,\n cooling_balance_points=cooling_balance_points,\n degree_day_method=degree_day_method,\n percent_hourly_coverage_per_day=percent_hourly_coverage_per_day,\n percent_hourly_coverage_per_billing_period=percent_hourly_coverage_per_billing_period,\n use_mean_daily_values=use_mean_daily_values,\n )\n )\n temp_agg_column_renames.update(\n {(\"temp\", \"degree_day_columns\"): \"degree_day_columns\"}\n )\n\n if data_quality:\n temp_agg_funcs.extend(\n [(\"not_null\", \"count\"), (\"null\", lambda x: x.isnull().sum())]\n )\n temp_agg_column_renames.update(\n {\n (\"temp\", \"not_null\"): \"temperature_not_null\",\n (\"temp\", \"null\"): \"temperature_null\",\n }\n )\n\n if temperature_mean:\n temp_agg_funcs.extend([(\"mean\", \"mean\")])\n temp_agg_column_renames.update({(\"temp\", \"mean\"): \"temperature_mean\"})\n\n # aggregate temperatures\n temp_df = temperature_data.to_frame(\"temp\")\n temp_groups = _matching_groups(meter_data_index, temp_df, tolerance)\n temp_aggregations = temp_groups.agg({\"temp\": temp_agg_funcs})\n\n # expand temp aggregations by faking and deleting the `meter_value` column.\n # I haven't yet figured out a way to avoid this and get the desired\n # structure and behavior. 
(philngo)\n meter_value = pd.DataFrame({\"meter_value\": 0}, index=meter_data_index)\n df = pd.concat([meter_value, temp_aggregations], axis=1).rename(\n columns=temp_agg_column_renames\n )\n del df[\"meter_value\"]\n\n if \"degree_day_columns\" in df:\n if df[\"degree_day_columns\"].dropna().empty:\n column_defaults = {\n column: np.full(df[\"degree_day_columns\"].shape, np.nan)\n for column in [\"n_days_dropped\", \"n_days_kept\"]\n }\n df = df.drop([\"degree_day_columns\"], axis=1).assign(**column_defaults)\n else:\n df = pd.concat(\n [\n df.drop([\"degree_day_columns\"], axis=1),\n df[\"degree_day_columns\"].dropna().apply(pd.Series),\n ],\n axis=1,\n )\n\n if not keep_partial_nan_rows:\n df = overwrite_partial_rows_with_nan(df)\n\n # nan last row\n df = df.iloc[:-1].reindex(df.index)\n return df\n```\n###test function signature:\n```python\n\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_compute_temperature_features_billing_monthly_hourly_degree_days", + "reference": "\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data = il_electricity_cdd_hdd_billing_monthly[\"meter_data\"]\n temperature_data = il_electricity_cdd_hdd_billing_monthly[\"temperature_data\"]\n df = compute_temperature_features(\n meter_data.index,\n temperature_data,\n heating_balance_points=[60, 61],\n cooling_balance_points=[65, 66],\n temperature_mean=False,\n degree_day_method=\"hourly\",\n )\n assert df.shape == (27, 6)\n assert list(sorted(df.columns)) == [\n \"cdd_65\",\n \"cdd_66\",\n \"hdd_60\",\n \"hdd_61\",\n \"n_hours_dropped\",\n \"n_hours_kept\",\n ]\n snapshot.assert_match(\n [\n round(df.hdd_60.mean(), 2),\n round(df.hdd_61.mean(), 2),\n round(df.cdd_65.mean(), 2),\n round(df.cdd_66.mean(), 2),\n round(df.n_hours_kept.mean(), 2),\n 
round(df.n_hours_dropped.mean(), 2),\n ],\n \"values\",\n )", + "completions": [ + "```python\ndef test_compute_temperature_features_billing_monthly_hourly_degree_days(\n il_electricity_cdd_hdd_billing_monthly, snapshot\n):\n meter_data_index = pd.date_range(\n start=\"2015-01-01\", end=\"2015-01-31\", freq=\"H\"\n )\n temperature_data = pd.DataFrame(\n {\n \"temperature_mean\": [\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 10,\n 1" + ], + "line": 168, + "token": 985, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 66, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef initialize_model_parallel(\n tensor_model_parallel_size: int = 1,\n pipeline_model_parallel_size: int = 1,\n virtual_pipeline_model_parallel_size: Optional[int] = None,\n pipeline_model_parallel_split_rank: Optional[int] = None,\n use_fp8: bool = False,\n) -> None:\n # Get world size and rank. Ensure some consistencies.\n assert torch.distributed.is_initialized()\n world_size: int = torch.distributed.get_world_size()\n\n if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:\n raise RuntimeError(\n f\"world_size ({world_size}) is not divisible by tensor_model_parallel_size \"\n f\"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})\"\n )\n\n data_parallel_size: int = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)\n\n num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size\n num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size\n num_data_parallel_groups: int = world_size // data_parallel_size\n\n if virtual_pipeline_model_parallel_size is not None:\n if not pipeline_model_parallel_size > 2:\n raise RuntimeError(\"pipeline-model-parallel size should be greater than 2 with \" \"interleaved schedule\")\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK\n global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0\n _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size\n\n if pipeline_model_parallel_split_rank is not None:\n global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK\n _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank\n\n rank = torch.distributed.get_rank()\n\n # Build the data-parallel groups.\n global _DATA_PARALLEL_GROUP\n 
global _DATA_PARALLEL_GROUP_GLOO\n global _DATA_PARALLEL_GLOBAL_RANKS\n assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'\n all_data_parallel_group_ranks = []\n for i in range(pipeline_model_parallel_size):\n start_rank = i * num_pipeline_model_parallel_groups\n end_rank = (i + 1) * num_pipeline_model_parallel_groups\n for j in range(tensor_model_parallel_size):\n ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)\n all_data_parallel_group_ranks.append(list(ranks))\n group = torch.distributed.new_group(ranks)\n group_gloo = torch.distributed.new_group(ranks, backend=\"gloo\")\n if rank in ranks:\n _DATA_PARALLEL_GROUP = group\n _DATA_PARALLEL_GROUP_GLOO = group_gloo\n _DATA_PARALLEL_GLOBAL_RANKS = ranks\n\n # Build the model-parallel groups.\n global _MODEL_PARALLEL_GROUP\n assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'\n for i in range(data_parallel_size):\n ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _MODEL_PARALLEL_GROUP = group\n\n # Build the tensor model-parallel groups.\n global _TENSOR_MODEL_PARALLEL_GROUP\n assert _TENSOR_MODEL_PARALLEL_GROUP is None, 'tensor model parallel group is already initialized'\n for i in range(num_tensor_model_parallel_groups):\n ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _TENSOR_MODEL_PARALLEL_GROUP = group\n\n # Build the pipeline model-parallel groups and embedding groups\n # (first and last rank in each pipeline model-parallel group).\n global _PIPELINE_MODEL_PARALLEL_GROUP\n global _PIPELINE_GLOBAL_RANKS\n assert _PIPELINE_MODEL_PARALLEL_GROUP is None, 'pipeline model parallel group is already initialized'\n global _EMBEDDING_GROUP\n global _EMBEDDING_GLOBAL_RANKS\n assert _EMBEDDING_GROUP is None, 'embedding 
group is already initialized'\n global _POSITION_EMBEDDING_GROUP\n global _POSITION_EMBEDDING_GLOBAL_RANKS\n assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'\n for i in range(num_pipeline_model_parallel_groups):\n ranks = range(i, world_size, num_pipeline_model_parallel_groups)\n group = torch.distributed.new_group(ranks)\n if rank in ranks:\n _PIPELINE_MODEL_PARALLEL_GROUP = group\n _PIPELINE_GLOBAL_RANKS = ranks\n # Setup embedding group (to exchange gradients between\n # first and last stages).\n if len(ranks) > 1:\n embedding_ranks = [ranks[0], ranks[-1]]\n position_embedding_ranks = [ranks[0]]\n if pipeline_model_parallel_split_rank is not None:\n if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:\n embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank], ranks[-1]]\n if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:\n position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]\n else:\n embedding_ranks = ranks\n position_embedding_ranks = ranks\n\n group = torch.distributed.new_group(embedding_ranks)\n if rank in embedding_ranks:\n _EMBEDDING_GROUP = group\n if rank in ranks:\n _EMBEDDING_GLOBAL_RANKS = embedding_ranks\n\n group = torch.distributed.new_group(position_embedding_ranks)\n if rank in position_embedding_ranks:\n _POSITION_EMBEDDING_GROUP = group\n if rank in ranks:\n _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks\n\n # Build the FP8 groups.\n global _AMAX_REDUCTION_GROUP\n assert _AMAX_REDUCTION_GROUP is None, \\\n 'FP8 amax reduction group is already initialized'\n if use_fp8:\n amax_group_size: int = tensor_model_parallel_size * data_parallel_size\n num_amax_groups: int = world_size // amax_group_size\n for i in range(num_amax_groups):\n start_rank = i * amax_group_size\n end_rank = (i + 1) * amax_group_size\n ranks = range(start_rank, end_rank)\n group = torch.distributed.new_group(ranks)\n if rank in 
ranks:\n _AMAX_REDUCTION_GROUP = group\n\n # Initialize global memory buffer\n # This isn't really \"parallel state\" but there isn't another good place to\n # put this. If we end up with a more generic initialization of megatron-core\n # we could stick it there\n _set_global_memory_buffer()\n```\n###test function signature:\n```python\n\ndef test_initialize__and_destroy_model_parallel():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_initialize__and_destroy_model_parallel", + "reference": "\ndef test_initialize__and_destroy_model_parallel():\n with pytest.raises(AssertionError):\n assert(ps.initialize_model_parallel())\n Utils.initialize_distributed()\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(tensor_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=2*world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(pipeline_model_parallel_size=world_size, tensor_model_parallel_size=world_size))\n with pytest.raises(RuntimeError):\n assert(ps.initialize_model_parallel(virtual_pipeline_model_parallel_size=2))\n Utils.initialize_model_parallel(tensor_model_parallel_size=2, pipeline_model_parallel_size=4)\n\n assert(ps.model_parallel_is_initialized())\n assert(ps.get_model_parallel_group() is not None)\n assert(ps.get_tensor_model_parallel_group() is not None)\n assert(ps.get_pipeline_model_parallel_group() is not None)\n assert(ps.get_data_parallel_group() is not None) \n Utils.destroy_model_parallel()\n assert(ps._MODEL_PARALLEL_GROUP is None)", + "completions": [ + "```python\ndef test_initialize_and_destroy_model_parallel():\n initialize_model_parallel(\n tensor_model_parallel_size=1,\n pipeline_model_parallel_size=1,\n virtual_pipeline_model_parallel_size=None,\n pipeline_model_parallel_split_rank=None,\n use_fp8=False,\n )\n destroy_model_parallel()\n```\n" + ], + "line": 133, + "token": 886, 
+ "line_diff": 6, + "token_diff": 6 + }, + { + "id": 67, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_spectrogram_shapes(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_spectrogram_shapes", + "reference": "\n def test_spectrogram_shapes(self):\n waveform = self._load_datasamples(1)[0]\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n power=1.0,\n center=False,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (201, 729))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\"),\n frame_length=400,\n hop_length=128,\n fft_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=True,\n )\n self.assertEqual(spec.shape, (257, 732))\n\n spec = spectrogram(\n waveform,\n window_function(400, \"hann\", frame_length=512),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=64,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 1464))\n\n spec = spectrogram(\n waveform,\n window_function(512, \"hann\"),\n frame_length=512,\n hop_length=512,\n power=1.0,\n center=True,\n pad_mode=\"reflect\",\n onesided=False,\n )\n self.assertEqual(spec.shape, (512, 183))", + "completions": [ + "```python\n def test_spectrogram_shapes(self):\n waveform = np.random.rand(1000)\n window = np.hanning(100)\n 
frame_length = 100\n hop_length = 10\n fft_length = 100\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = None\n mel_floor = 1e-10\n log_mel = None\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (10, 100))\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 68, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = 
resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_base64_image_given_and_no_resize_needed", + "reference": "\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n _, image = example_local_image\n _, buffer = cv2.imencode(\".jpg\", image)\n base64_image = base64.b64encode(buffer).decode(\"ascii\")\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=base64_image,\n max_width=None,\n max_height=None,\n )\n\n # then\n assert (\n scaling_factor is None\n ), \"No resize parameters given, so scaling factor should not be established\"\n assert (\n serialised_image == base64_image\n ), \"Serialised image should be identical with input\"", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_base64_image_given_and_no_resize_needed(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n image_bytes = base64.b64decode(example_local_image[0])\n image = bytes_to_opencv_image(payload=image_bytes)\n # when\n result = load_image_from_string(\n reference=example_local_image[0],\n max_height=None,\n max_width=None,\n )\n # then\n assert result[0] == example_local_image[0]\n assert result[1] is None\n assert np.array_equal(image, result[1])\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 69, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def submit_payout_address(self, robot_index=1):\n path = reverse(\"order\")\n params = f\"?order_id={self.order_id}\"\n headers = self.get_robot_auth(robot_index)\n\n payout_address = create_address(\"robot\")\n signed_payout_address = sign_message(\n payout_address,\n passphrase_path=f\"tests/robots/{robot_index}/token\",\n private_key_path=f\"tests/robots/{robot_index}/enc_priv_key\",\n )\n body = {\n \"action\": \"update_address\",\n \"address\": signed_payout_address,\n \"mining_fee_rate\": 50,\n }\n self.response = self.client.post(path + params, body, **headers)\n```\n###test function signature:\n```python\n def test_successful_onchain(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_successful_onchain", + "reference": " def test_successful_onchain(self):\n trade = Trade(self.client)\n trade.publish_order()\n trade.take_order()\n trade.lock_taker_bond()\n trade.lock_escrow(trade.taker_index)\n trade.submit_payout_address(trade.maker_index)\n trade.confirm_fiat(trade.maker_index)\n trade.confirm_fiat(trade.taker_index)\n\n trade.process_payouts(mine_a_block=True)\n trade.get_order(trade.maker_index)\n\n data = trade.response.json()\n\n self.assertEqual(trade.response.status_code, 200)\n self.assertResponse(trade.response)\n\n self.assertEqual(data[\"status_message\"], Order.Status(Order.Status.SUC).label)\n self.assertTrue(data[\"is_fiat_sent\"])\n self.assertFalse(data[\"is_disputed\"])\n self.assertIsInstance(data[\"maker_summary\"][\"address\"], str)\n self.assertIsHash(data[\"maker_summary\"][\"txid\"])\n\n self.assert_order_logs(data[\"id\"])", + "completions": [ + "```python\n def test_successful_onchain(self):\n self.submit_payout_address(robot_index=1)\n self.assertEqual(self.response.status_code, 200)\n 
self.assertEqual(self.response.json()[\"status\"], \"success\")\n```\n" + ], + "line": 26, + "token": 183, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 70, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n 
header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_blank_line():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_blank_line", + "reference": "\ndef test_header_to_metadata_and_cell_blank_line():\n text = \"\"\"---\ntitle: Sample header\n---\n\nHeader is followed by a blank line\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {}\n assert lines[pos].startswith(\"Header is\")", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_blank_line():\n lines = [\"#' ---\", \"\"]\n header_prefix = \"#'\"\n header_suffix = \"\"\n ext = None\n root_level_metadata_as_raw_cell = True\n metadata, jupyter, cell, i = header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext, 
root_level_metadata_as_raw_cell\n )\n assert metadata == {}\n assert jupyter == []\n assert cell == new_raw_cell(source=\"\\n\".join([\"---\"]), metadata={\"lines_to_next_cell\": 1})\n assert i == 1\n```\n" + ], + "line": 104, + "token": 642, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 71, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def _create_output(\n self,\n accepted: torch.Tensor, # [batch_size, k]\n recovered_token_ids: torch.Tensor, # [batch_size, k]\n draft_token_ids: torch.Tensor, # [batch_size, k]\n bonus_token_ids: torch.Tensor, # [batch_size]\n ) -> torch.Tensor:\n bonus_token_ids = bonus_token_ids.squeeze()\n batch_size, k = recovered_token_ids.shape\n\n # Determine the index of the first False value for each row.\n limits = (accepted == 0).max(1).indices\n limits[~(accepted == 0).any(1)] = k\n\n # Create masks using the indices.\n indices = torch.arange(k, device=accepted.device).unsqueeze(0)\n accepted_mask = indices < limits.unsqueeze(1)\n after_false_mask = indices == limits.unsqueeze(1)\n\n # Create an extended output tensor\n output_with_bonus_tokens = -torch.ones(\n (batch_size, k + self._num_bonus_tokens),\n dtype=self.token_id_dtype,\n device=accepted.device)\n output = output_with_bonus_tokens[:, :k]\n\n # Fill in the first k columns of the output tensor using masks and data\n # tensors.\n output[:, :k] = torch.where(accepted_mask, draft_token_ids,\n -torch.ones_like(draft_token_ids))\n\n # Fill the last column.\n # We check output directly as accepted may have True values inconsistent\n # with causal acceptance.\n output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,\n bonus_token_ids, -1)\n\n # We disable bonus tokens because it causes corrupt KV cache for\n # proposal methods that require KV cache. 
We can fix it by \"prefilling\"\n # the bonus token in the proposer. The following issue tracks the fix.\n # https://github.com/vllm-project/vllm/issues/4212\n output_with_bonus_tokens[:, -1] = -1\n\n # Fill the recovered token ids.\n output.mul_(~after_false_mask).add_(\n recovered_token_ids.mul(after_false_mask))\n\n self.num_accepted_tokens += accepted.sum()\n self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()\n self.num_draft_tokens += batch_size * k\n\n return output_with_bonus_tokens\n```\n###test function signature:\n```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_correct_output_format", + "reference": "def test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n set_random_seed(seed)\n torch.set_default_device(device)\n\n batch_size = 10\n k = 5\n vocab_size = 3000\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"no_tokens_accepted\":\n accepted = mock_causal_accepted_tensor(\n k, -torch.ones((batch_size, ), dtype=torch.long))\n elif which_tokens_accepted == \"some_tokens_accepted\":\n last_accepted_indices = torch.randint(low=-1,\n high=k,\n size=(batch_size, ))\n accepted = mock_causal_accepted_tensor(k, last_accepted_indices)\n else:\n raise AssertionError()\n\n recovered_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n draft_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, k),\n dtype=torch.int64)\n bonus_token_ids = torch.randint(low=0,\n high=vocab_size,\n size=(batch_size, 1),\n dtype=torch.int64)\n\n rejection_sampler = RejectionSampler()\n rejection_sampler.init_gpu_tensors(rank=0)\n output_token_ids = rejection_sampler._create_output( # pylint: 
disable=protected-access\n accepted,\n recovered_token_ids,\n draft_token_ids,\n bonus_token_ids,\n )\n\n # Bonus tokens are currently disabled. Verify they're set to -1.\n # See https://github.com/vllm-project/vllm/issues/4212\n expected_bonus_token_ids = bonus_token_ids.clone() * 0 - 1\n\n if which_tokens_accepted == \"all_tokens_accepted\":\n # Expect all tokens to be equal to draft tokens.\n assert torch.equal(output_token_ids[:, :-1], draft_token_ids)\n\n # Expect all bonus tokens to be included.\n assert torch.equal(output_token_ids[:, -1:], expected_bonus_token_ids)\n elif which_tokens_accepted == \"no_tokens_accepted\":\n # Expect first token to be equal to recovered tokens.\n assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])\n\n # Expect everything else to be -1.\n assert torch.equal(output_token_ids[:, 1:],\n torch.ones_like(output_token_ids[:, 1:]) * -1)\n elif which_tokens_accepted == \"some_tokens_accepted\":\n recovered_plus_bonus = torch.cat(\n (recovered_token_ids, expected_bonus_token_ids), dim=-1)\n # Assert first rejected token is a recovered token or bonus token.\n assert torch.equal(\n recovered_plus_bonus[torch.arange(0, batch_size),\n last_accepted_indices + 1],\n output_token_ids[torch.arange(0, batch_size),\n last_accepted_indices + 1])\n\n # Assert every subsequent token is -1.\n subsequent_mask = torch.arange(0, k + 1).expand(\n batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)\n assert torch.all(output_token_ids[subsequent_mask] == -1)", + "completions": [ + "```python\ndef test_correct_output_format(which_tokens_accepted: str, seed: int,\n device: str):\n torch.manual_seed(seed)\n batch_size = 10\n k = 10\n num_bonus_tokens = 1\n accepted = torch.randint(0, 2, (batch_size, k), dtype=torch.int32)\n recovered_token_ids = torch.randint(0, 10, (batch_size, k),\n dtype=torch.int32)\n draft_token_ids = torch.randint(0, 10, (batch_size, k), dtype=torch.int32)\n bonus_token_ids = torch.randint(0, 10, 
(batch_size,), dtype=torch.int32)\n output = _create_output(accepted, recovered_token_ids, draft_token_ids,\n bonus_token_ids)\n assert output.shape == (batch_size, k + num_bonus_tokens)\n assert output.dtype == torch.int32\n assert output.device == device\n assert (output[:, :k] == torch.where(accepted, draft_token_ids,\n -torch.ones_like(draft_token_ids))).all()\n assert (output[:, -1] == torch.where(output[:, -1] != -1, bonus_token_ids,\n -1)).all()\n assert (output[:, -1] == -1).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted == 0).any(1))).all()\n assert (output.mul(~(accepted == 0).any(1)) +\n recovered_token_ids.mul((accepted" + ], + "line": 54, + "token": 433, + "line_diff": 4, + 
"token_diff": 5 + }, + { + "id": 72, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_array(fp, allow_pickle=False, pickle_kwargs=None, *,\n max_header_size=_MAX_HEADER_SIZE):\n if allow_pickle:\n # Effectively ignore max_header_size, since `allow_pickle` indicates\n # that the input is fully trusted.\n max_header_size = 2**64\n\n version = read_magic(fp)\n _check_version(version)\n shape, fortran_order, dtype = _read_array_header(\n fp, version, max_header_size=max_header_size)\n if len(shape) == 0:\n count = 1\n else:\n count = numpy.multiply.reduce(shape, dtype=numpy.int64)\n\n # Now read the actual data.\n if dtype.hasobject:\n # The array contained Python objects. We need to unpickle the data.\n if not allow_pickle:\n raise ValueError(\"Object arrays cannot be loaded when \"\n \"allow_pickle=False\")\n if pickle_kwargs is None:\n pickle_kwargs = {}\n try:\n array = pickle.load(fp, **pickle_kwargs)\n except UnicodeError as err:\n # Friendlier error message\n raise UnicodeError(\"Unpickling a python object failed: %r\\n\"\n \"You may need to pass the encoding= option \"\n \"to numpy.load\" % (err,)) from err\n else:\n if isfileobj(fp):\n # We can use the fast fromfile() function.\n array = numpy.fromfile(fp, dtype=dtype, count=count)\n else:\n # This is not a real file. We have to read it the\n # memory-intensive way.\n # crc32 module fails on reads greater than 2 ** 32 bytes,\n # breaking large reads from gzip streams. Chunk reads to\n # BUFFER_SIZE bytes to avoid issue and reduce memory overhead\n # of the read. 
In non-chunked case count < max_read_count, so\n # only one read is performed.\n\n # Use np.ndarray instead of np.empty since the latter does\n # not correctly instantiate zero-width string dtypes; see\n # https://github.com/numpy/numpy/pull/6430\n array = numpy.ndarray(count, dtype=dtype)\n\n if dtype.itemsize > 0:\n # If dtype.itemsize == 0 then there's nothing more to read\n max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)\n\n for i in range(0, count, max_read_count):\n read_count = min(max_read_count, count - i)\n read_size = int(read_count * dtype.itemsize)\n data = _read_bytes(fp, read_size, \"array data\")\n array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n count=read_count)\n\n if fortran_order:\n array.shape = shape[::-1]\n array = array.transpose()\n else:\n array.shape = shape\n\n return array\n```\n###test function signature:\n```python\n\ndef test_version_2_0():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_version_2_0", + "reference": "\ndef test_version_2_0():\n f = BytesIO()\n # requires more than 2 byte for header\n dt = [((\"%d\" % i) * 100, float) for i in range(500)]\n d = np.ones(1000, dtype=dt)\n\n format.write_array(f, d, version=(2, 0))\n with warnings.catch_warnings(record=True) as w:\n warnings.filterwarnings('always', '', UserWarning)\n format.write_array(f, d)\n assert_(w[0].category is UserWarning)\n\n # check alignment of data portion\n f.seek(0)\n header = f.readline()\n assert_(len(header) % format.ARRAY_ALIGN == 0)\n\n f.seek(0)\n n = format.read_array(f, max_header_size=200000)\n assert_array_equal(d, n)\n\n # 1.0 requested but data cannot be saved this way\n assert_raises(ValueError, format.write_array, f, d, (1, 0))", + "completions": [ + "```python\ndef test_version_2_0():\n fp = open(\"test_file.npy\", \"rb\")\n read_array(fp)\n```\n" + ], + "line": 70, + "token": 516, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 73, + "prompt": "\nYou are an expert Python 
programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef log_mel_spectrogram(audio: Union[str, np.ndarray, torch.Tensor], n_mels: int = N_MELS):\n if not torch.is_tensor(audio):\n if isinstance(audio, str):\n audio = load_audio(audio)\n audio = torch.from_numpy(audio)\n\n window = torch.hann_window(N_FFT).to(audio.device)\n stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)\n magnitudes = stft[..., :-1].abs() ** 2\n\n filters = mel_filters(audio.device, n_mels)\n mel_spec = filters @ magnitudes\n\n log_spec = torch.clamp(mel_spec, min=1e-10).log10()\n log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)\n log_spec = (log_spec + 4.0) / 4.0\n return log_spec\n```\n###test function signature:\n```python\n\ndef test_audio():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_audio", + "reference": "\ndef test_audio():\n audio_path = os.path.join(os.path.dirname(__file__), \"jfk.flac\")\n audio = load_audio(audio_path)\n assert audio.ndim == 1\n assert SAMPLE_RATE * 10 < audio.shape[0] < SAMPLE_RATE * 12\n assert 0 < audio.std() < 1\n\n mel_from_audio = log_mel_spectrogram(audio)\n mel_from_file = log_mel_spectrogram(audio_path)\n\n assert np.allclose(mel_from_audio, mel_from_file)\n assert mel_from_audio.max() - mel_from_audio.min() <= 2.0", + "completions": [ + "```python\ndef test_audio():\n audio = torch.randn(1, 16000)\n log_mel_spectrogram(audio)\n```\n" + ], + "line": 24, + "token": 213, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 74, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef transform_grid_cell(\n grid: Grid,\n cell: npt.NDArray[np.float_] = np.eye(3),\n output_shape: Optional[Tuple[int, int, int]] = None,\n mode: str = \"constant\",\n order: int = 1,\n **kwargs,\n) -> Grid:\n # Take the current shape of the grid if no output shape was provided\n if output_shape is None:\n output_shape = grid.shape\n\n # Make sure the cell has type float\n cell = np.asarray(cell, dtype=float)\n\n # Get the current cell in coordinates of the destination axes\n inv_cell = cell_invert(cell).T\n projected_cell = grid.cell.dot(inv_cell)\n\n # From that, infere how long will the bounding box of the cell be\n lengths = abs(projected_cell).sum(axis=0)\n\n # Create the transformation matrix. Since we want to control the shape\n # of the output, we can not use grid.dcell directly, we need to modify it.\n scales = output_shape / lengths\n forward_t = (grid.dcell.dot(inv_cell) * scales).T\n\n # Scipy's affine transform asks for the inverse transformation matrix, to\n # map from output pixels to input pixels. 
By taking the inverse of our\n # transformation matrix, we get exactly that.\n tr = cell_invert(forward_t).T\n\n # Calculate the offset of the image so that all points of the grid \"fall\" inside\n # the output array.\n # For this we just calculate the centers of the input and output images\n center_input = 0.5 * (_a.asarrayd(grid.shape) - 1)\n center_output = 0.5 * (_a.asarrayd(output_shape) - 1)\n\n # And then make sure that the input center that is interpolated from the output\n # falls in the actual input's center\n offset = center_input - tr.dot(center_output)\n\n # We pass all the parameters to scipy's affine_transform\n transformed_image = affine_transform(\n grid.grid,\n tr,\n order=1,\n offset=offset,\n output_shape=output_shape,\n mode=mode,\n **kwargs,\n )\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = grid.__class__((1, 1, 1), lattice=cell * lengths.reshape(3, 1))\n new_grid.grid = transformed_image\n new_grid.geometry = grid.geometry\n new_grid.lattice.origin = grid.origin + new_grid.dcell.dot(forward_t.dot(offset))\n\n # Find the offset between the origin before and after the transformation\n return new_grid\n```\n###test function signature:\n```python\n\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_transform_grid_cell", + "reference": "\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n new_grid = transform_grid_cell(grid, cell=np.eye(3), output_shape=(10, 10, 10))\n\n assert new_grid.shape == (10, 10, 10)\n assert new_grid.lattice.is_cartesian()\n\n if not skewed:\n assert np.allclose(new_grid.lattice.cell, grid.lattice.cell)\n\n assert np.allclose(new_grid.grid, grid.grid) == (not skewed)\n assert not np.allclose(new_grid.grid, 0)\n\n assert np.allclose(new_grid.origin, grid.origin) == 
(not skewed)\n\n # Convert to a skewed cell\n directions = np.array([[1, 2, 3], [-1, 2, -4], [2, -1, 1]])\n new_grid = transform_grid_cell(grid, cell=directions, output_shape=(5, 5, 5))\n\n assert new_grid.shape == (5, 5, 5)\n for i in range(3):\n n = new_grid.lattice.cell[i] / directions[i]\n assert np.allclose(n, n[0])", + "completions": [ + "```python\ndef test_transform_grid_cell(grid, skewed):\n # Convert to a cartesian cell\n cell = grid.cell.dot(cell_invert(grid.cell))\n\n # Create a new grid with the new shape and the new cell (notice how the cell\n # is rescaled from the input cell to fit the actual coordinates of the system)\n new_grid = transform_grid_cell(grid, cell, grid.shape, mode=\"constant\", order=1)\n\n # Check that the new grid has the same shape and cell as the original grid\n assert new_grid.shape == grid.shape\n assert np.allclose(new_grid.cell, grid.cell)\n\n # Check that the new grid has the same geometry as the original grid\n assert new_grid.geometry == grid.geometry\n\n # Check that the new grid has the same origin as the original grid\n assert np.allclose(new_grid.origin, grid.origin)\n\n # Check that the new grid has the same dcell as the original grid\n assert np.allclose(new_grid.dcell, grid.dcell)\n\n # Check that the new grid has the same grid as the original grid\n assert np.allclose(new_grid.grid, grid.grid)\n\n # Check that the new grid has the same skewed as the original grid\n assert np.allclose(new_grid.skewed, grid.skewed)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert 
np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid.grid_spacing, grid.grid_spacing)\n\n # Check that the new grid has the same grid_shape as the original grid\n assert np.allclose(new_grid.grid_shape, grid.grid_shape)\n\n # Check that the new grid has the same grid_origin as the original grid\n assert np.allclose(new_grid.grid_origin, grid.grid_origin)\n\n # Check that the new grid has the same grid_spacing as the original grid\n assert np.allclose(new_grid" + ], + "line": 63, + "token": 533, + "line_diff": 5, + "token_diff": 5 + }, + { + "id": 75, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef similarity_matrix(xq: np.ndarray, index: np.ndarray) -> np.ndarray:\n\n index_norm = norm(index, axis=1)\n xq_norm = norm(xq.T)\n sim = np.dot(index, xq.T) / (index_norm * xq_norm)\n return sim\n```\n###test function signature:\n```python\ndef test_top_scores__is_sorted(test_index):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_top_scores__is_sorted", + "reference": "def test_top_scores__is_sorted(test_index):\n\n xq = test_index[0] # should have max similarity\n\n sim = similarity_matrix(xq, test_index)\n _, idx = top_scores(sim, 3)\n\n # Scores and indexes should be sorted ascending\n assert np.array_equal(idx, np.array([2, 1, 0]))", + "completions": [ + "```python\ndef test_top_scores__is_sorted(test_index):\n sim = similarity_matrix(xq, test_index)\n assert np.all(np.diff(sim, axis=1) >= 0)\n```\n" + ], + "line": 15, + "token": 125, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 76, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef _generate_nanoplot(\n y_vals: Union[List[Union[int, float]], List[int], List[float]],\n y_ref_line: Optional[str] = None,\n y_ref_area: Optional[str] = None,\n x_vals: \"List[Union[int, float]] | None\" = None,\n expand_x: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n expand_y: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n missing_vals: str = \"marker\",\n all_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n all_single_y_vals: Optional[Union[List[Union[int, float]], List[int], List[float]]] = None,\n plot_type: str = \"line\",\n data_line_type: str = \"curved\",\n currency: Optional[str] = None,\n y_val_fmt_fn: Optional[Callable[..., str]] = None,\n y_axis_fmt_fn: Optional[Callable[..., str]] = None,\n y_ref_line_fmt_fn: Optional[Callable[..., str]] = None,\n data_point_radius: Union[int, List[int]] = 10,\n data_point_stroke_color: Union[str, List[str]] = \"#FFFFFF\",\n data_point_stroke_width: Union[int, List[int]] = 4,\n data_point_fill_color: Union[str, List[str]] = \"#FF0000\",\n data_line_stroke_color: str = \"#4682B4\",\n data_line_stroke_width: int = 8,\n data_area_fill_color: str = \"#FF0000\",\n data_bar_stroke_color: Union[str, List[str]] = \"#3290CC\",\n data_bar_stroke_width: Union[int, List[int]] = 4,\n data_bar_fill_color: Union[str, List[str]] = \"#3FB5FF\",\n data_bar_negative_stroke_color: str = \"#CC3243\",\n data_bar_negative_stroke_width: int = 4,\n data_bar_negative_fill_color: str = \"#D75A68\",\n reference_line_color: str = \"#75A8B0\",\n reference_area_fill_color: str = \"#A6E6F2\",\n vertical_guide_stroke_color: str = \"#911EB4\",\n vertical_guide_stroke_width: int = 12,\n show_data_points: bool = True,\n show_data_line: bool = True,\n 
show_data_area: bool = True,\n show_reference_line: bool = True,\n show_reference_area: bool = True,\n show_vertical_guides: bool = True,\n show_y_axis_guide: bool = True,\n interactive_data_values: bool = True,\n svg_height: str = \"2em\",\n) -> str:\n\n # Ensure that arguments are matched\n _match_arg(\n x=missing_vals,\n lst=[\"marker\", \"gap\", \"zero\", \"remove\"],\n )\n _match_arg(\n x=data_line_type,\n lst=[\"curved\", \"straight\"],\n )\n\n #\n # Determine where a zero line is considered and provide the stroke color and width\n #\n\n zero_line_considered = True if plot_type in [\"bar\", \"boxplot\"] else False\n\n zero_line_stroke_color = \"#BFBFBF\"\n zero_line_stroke_width = 4\n\n # Initialize several local `*_tags` variables with `None`\n ref_area_tags = None\n area_path_tags = None\n data_path_tags = None\n zero_line_tags = None\n bar_tags = None\n boxplot_tags = None\n ref_line_tags = None\n circle_tags = None\n g_y_axis_tags = None\n g_guide_tags = None\n\n # Initialize the `single_horizontal_plot` variable with `False`\n single_horizontal_plot = False\n\n # If the number of `y` values in a list is zero or if all consist of NA values,\n # return an empty string\n if type(y_vals) is list and len(y_vals) == 0:\n return \"\"\n\n # If all `y` values are NA, return an empty string\n # TODO: all([]) evaluates to True. 
In that case does this produce the intended behavior?\n if isinstance(y_vals, list) and all(_map_is_na(y_vals)):\n return \"\"\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # Handle case where `x_vals` exists (i.e., is not `NULL`)\n if x_vals is not None:\n\n # If the number of `x` values is zero or an empty string,\n # return an empty string\n if len(x_vals) == 0:\n return \"\"\n if all(_map_is_na(x_vals)):\n return \"\"\n\n # Get the number of data points for `x`\n num_x_vals = len(x_vals)\n\n # Ensure that, if there are `x` values, the number of `x`\n # and `y` values matches\n if num_x_vals != num_y_vals:\n raise ValueError(\n f\"\"\"The number of `x` and `y` values must match.\n The `x` value length is: {num_x_vals}\n The `y` value length is: {num_y_vals}\n \"\"\"\n )\n\n # Handle missing values in `x_vals` through removal (i.e., missing\n # values in `x_vals` means removal of positional values from both\n # `x_vals` and `y_vals`)\n if any(_map_is_na(x_vals)):\n # TODO: this code did not have test coverage and likely didn't\n # work. 
It should work now, but we need to test it.\n\n # Determine which values from `x_vals` are non-missing values\n x_vals_non_missing = [~_is_na(val) for val in x_vals]\n\n # Retain only `x_vals_non_missing` from `x_vals` and `y_vals`\n x_vals = [x for x, keep in zip(x_vals, x_vals_non_missing) if keep]\n y_vals = [y for y, keep in zip(y_vals, x_vals_non_missing) if keep]\n\n # If `x` values are present, we cannot use a curved line so\n # we'll force the use of the 'straight' line type\n # TODO: if someone specifies the options curved, and we can't do it\n # then we should raise an error.\n data_line_type = \"straight\"\n\n # If `missing_vals` is set to 'gap' raise an error\n # TODO: Implement the 'gap' option for missing values\n if missing_vals == \"gap\":\n raise NotImplementedError(\"The 'gap' option for missing values is not yet implemented.\")\n\n # For the `missing_vals` options of 'zero' or 'remove', either replace NAs\n # with `0` or remove NAs entirely\n if missing_vals == \"zero\":\n y_vals = y_vals.fillna(0)\n\n # If `missing_vals` is 'remove', remove NAs from `y_vals`\n if missing_vals == \"remove\":\n y_vals = y_vals.dropna()\n\n if x_vals is not None:\n # Remove the corresponding `x` values for the removed `y` values\n x_vals = x_vals[y_vals.index]\n\n # Get the number of data points for `y`\n if type(y_vals) is list:\n num_y_vals = len(y_vals)\n else:\n num_y_vals = 1\n\n # If `y_vals` is a scalar value we requested a 'line' or 'bar' plot, then\n # reset several parameters\n if type(y_vals) in [int, float] and plot_type in [\"line\", \"bar\"]:\n\n single_horizontal_plot = True\n show_data_points = True\n show_data_line = True\n show_data_area = False\n show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n y_vals = [y_vals]\n\n # If this is a box plot, set several parameters\n if plot_type == \"boxplot\":\n\n show_data_points = False\n show_data_line = False\n show_data_area = False\n 
show_reference_line = False\n show_reference_area = False\n show_vertical_guides = False\n show_y_axis_guide = False\n\n # Find out whether the collection of non-NA `y` values are all integer-like\n y_vals_integerlike = _is_integerlike(val_list=y_vals)\n\n # Get the max and min of the `y` scale from the `y` data values\n y_scale_max = _get_extreme_value(y_vals, stat=\"max\")\n y_scale_min = _get_extreme_value(y_vals, stat=\"min\")\n\n # Handle cases where collection of `y_vals` is invariant\n if y_scale_min == y_scale_max and expand_y is None:\n\n if y_scale_min == 0:\n expand_y_dist = 5\n else:\n expand_y_dist = (y_scale_min / 10) * 2\n\n # Expand the `y` scale, centering around the `y_scale_min` value\n expand_y = [y_scale_min - expand_y_dist, y_scale_min + expand_y_dist]\n\n # Ensure that a reference line or reference area isn't shown if None or\n # any of its directives is missing\n if _is_na(y_ref_line):\n show_reference_line = False\n\n if y_ref_area is None:\n show_reference_area = False\n elif _is_na(y_ref_area[0]) or _is_na(y_ref_area[1]):\n show_reference_area = False\n\n # Determine the width of the data plot area; for plots where `x_vals`\n # are available, we'll use a fixed width of `500` (px), and for plots\n # where `x_vals` aren't present, we'll adjust the final width based\n # on the fixed interval between data points (this is dependent on the\n # number of data points)\n if x_vals is not None or single_horizontal_plot or plot_type == \"boxplot\":\n data_x_width = 600\n # TODO: what should x_d be in this case?\n else:\n # Obtain a sensible, fixed interval between data points in px\n if num_y_vals <= 20:\n x_d = 50\n elif num_y_vals <= 30:\n x_d = 40\n elif num_y_vals <= 40:\n x_d = 30\n elif num_y_vals <= 50:\n x_d = 25\n else:\n x_d = 20\n\n data_x_width = num_y_vals * x_d\n\n # Define the top-left of the plot area\n left_x = 0\n top_y = 0\n\n # Define the safe zone distance from top/bottom and left/right edges\n safe_y_d = 15\n safe_x_d = 50\n\n 
# Define the height of the plot area that bounds the data points\n data_y_height = 100\n\n # Determine the bottom-right of the plot area based on the quantity of data\n bottom_y = safe_y_d + data_y_height + safe_y_d\n right_x = safe_x_d + data_x_width + safe_x_d\n\n viewbox = f\"{left_x} {top_y} {right_x} {bottom_y}\"\n\n #\n # If there is a reference line and/or reference area, the values for these\n # need to be generated and integrated in the `normalize_y_vals()` operation\n # so that there are normalized values in relation to the data points\n #\n\n if show_reference_line and show_reference_area:\n\n # Case where there is both a reference line and a reference area\n\n #\n # Resolve the reference line\n #\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n #\n # Resolve the reference area\n #\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n _all_y_data = [y_vals, y_ref_line, y_ref_area_l, y_ref_area_u, expand_y]\n\n # Recompute the `y` scale min and max values\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u 
= y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference line and reference area boundaries\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n elif show_reference_line:\n\n # Case where there is a reference line\n\n if (\n y_ref_line is not None\n and _val_is_str(y_ref_line)\n and y_ref_line in REFERENCE_LINE_KEYWORDS\n ):\n y_ref_line = _generate_ref_line_from_keyword(vals=y_vals, keyword=y_ref_line)\n\n # Recompute the `y` scale min and max values\n args = [y_vals, y_ref_line, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_line=y_ref_line,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportion_ref_line = y_proportions_list[\"ref_line\"][0]\n\n # Scale reference line\n data_y_ref_line = safe_y_d + ((1 - y_proportion_ref_line) * data_y_height)\n\n elif show_reference_area:\n\n # Case where there is a reference area\n\n # Note if y_ref_area were None, we would not be in this clause\n y_ref_area_line_1 = calc_ref_value(y_ref_area[0], y_vals)\n y_ref_area_line_2 = calc_ref_value(y_ref_area[1], y_vals)\n\n y_ref_area_lines_sorted = sorted([y_ref_area_line_1, y_ref_area_line_2])\n y_ref_area_l = y_ref_area_lines_sorted[0]\n y_ref_area_u = y_ref_area_lines_sorted[1]\n\n # Recompute the `y` scale min and max values\n\n # Recompute the `y` scale min and max values\n _all_y_data = [y_vals, y_ref_area_l, y_ref_area_u, expand_y] + (\n [0] if zero_line_considered else []\n )\n y_scale_max = _get_extreme_value(*_all_y_data, stat=\"max\")\n y_scale_min = _get_extreme_value(*_all_y_data, 
stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n ref_area_l=y_ref_area_l,\n ref_area_u=y_ref_area_u,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n y_proportions_ref_area_l = y_proportions_list[\"ref_area_l\"][0]\n y_proportions_ref_area_u = y_proportions_list[\"ref_area_u\"][0]\n\n # Scale reference area boundaries\n data_y_ref_area_l = safe_y_d + ((1 - y_proportions_ref_area_l) * data_y_height)\n data_y_ref_area_u = safe_y_d + ((1 - y_proportions_ref_area_u) * data_y_height)\n\n else:\n\n # Case where there is no reference line or reference area\n\n # Recompute the `y` scale min and max values\n args = [y_vals, expand_y] + ([0] if zero_line_considered else [])\n y_scale_max = _get_extreme_value(*args, stat=\"max\")\n y_scale_min = _get_extreme_value(*args, stat=\"min\")\n\n y_proportions_list = _normalize_to_dict(\n vals=y_vals,\n zero=0 if zero_line_considered else None,\n expand_y=expand_y,\n )\n\n y_proportions = y_proportions_list[\"vals\"]\n\n # Calculate the `data_y0_point` value for zero-line-inclusive plots\n if zero_line_considered:\n y_proportions_zero = y_proportions_list[\"zero\"][0]\n data_y0_point = safe_y_d + ((1 - y_proportions_zero) * data_y_height)\n\n # If x values are present then normalize them between [0, 1]; if\n # there are no x values, generate equally-spaced x values according\n # to the number of y values\n if plot_type == \"line\" and x_vals is not None:\n\n if expand_x is not None and _val_is_str(expand_x):\n\n # TODO: the line below lacked tests, and called non-existent methods.\n # replace with something that doesn't use pandas and returns the correct thing.\n\n # Assume that string values are dates and convert them to timestamps\n # expand_x = pd.to_datetime(expand_x, utc=True).timestamp()\n raise NotImplementedError(\"Currently, passing expand_x as a string is unsupported.\")\n\n # Scale to proportional values\n x_proportions_list = 
_normalize_to_dict(vals=x_vals, expand_x=expand_x)\n\n x_proportions = x_proportions_list[\"vals\"]\n\n else:\n x_proportions = np.linspace(0, 1, num_y_vals)\n\n #\n # Create normalized (and inverted for SVG) data `x` and `y` values for the\n # plot area; these are named `data_x_points` and `data_y_points`\n #\n\n for i in range(len(y_proportions)):\n y_proportions[i] = safe_y_d + ((1 - y_proportions[i]) * data_y_height)\n\n for i in range(len(x_proportions)):\n x_proportions[i] = (data_x_width * x_proportions[i]) + safe_x_d\n\n data_y_points = y_proportions\n data_x_points = x_proportions\n\n #\n # Ensure that certain options have their lengths checked and\n # expanded to length `num_y_vals`\n #\n\n data_point_radius = _normalize_option_list(option_list=data_point_radius, num_y_vals=num_y_vals)\n data_point_stroke_color = _normalize_option_list(\n option_list=data_point_stroke_color, num_y_vals=num_y_vals\n )\n data_point_stroke_width = _normalize_option_list(\n option_list=data_point_stroke_width, num_y_vals=num_y_vals\n )\n data_point_fill_color = _normalize_option_list(\n option_list=data_point_fill_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_color = _normalize_option_list(\n option_list=data_bar_stroke_color, num_y_vals=num_y_vals\n )\n data_bar_stroke_width = _normalize_option_list(\n option_list=data_bar_stroke_width, num_y_vals=num_y_vals\n )\n data_bar_fill_color = _normalize_option_list(\n option_list=data_bar_fill_color, num_y_vals=num_y_vals\n )\n\n #\n # Generate data segments by defining `start` and `end` vectors (these\n # are guaranteed to be of the same length); these the segments of data\n # they receive line segments and adjoining areas\n #\n\n # Use run-length encoding to determine the segments of data\n\n # rle_data_y_points = pd.Series(data_y_points).diff().ne(0).cumsum()\n\n # end_data_y_points = np.cumsum(rle_data_y_points.lengths)\n\n # start_data_y_points = end_data_y_points - rle_data_y_points.lengths + 1\n\n start_data_y_points = 
[0]\n end_data_y_points = [len(data_y_points)]\n n_segments = 1\n\n #\n # Generate a curved data line\n #\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"curved\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n curve_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n curve_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n curved_path_string = [f\"M {curve_x[0]},{curve_y[0]}\"]\n\n for j in range(1, len(curve_x)):\n\n point_b1 = f\"{curve_x[j - 1] + x_d / 2},{curve_y[j - 1]}\"\n point_b2 = f\"{curve_x[j] - x_d / 2},{curve_y[j]}\"\n point_i = f\"{curve_x[j]},{curve_y[j]}\"\n\n path_string_j = f\"C {point_b1} {point_b2} {point_i}\"\n\n curved_path_string.append(path_string_j)\n\n curved_path_string_i = \" \".join(curved_path_string)\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\\n\".join(data_path_tags)\n\n if plot_type == \"line\" and show_data_line and data_line_type == \"straight\":\n\n data_path_tags = []\n\n for i in range(n_segments):\n\n line_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n line_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n line_xy = \" \".join([f\"{x},{y}\" for x, y in zip(line_x, line_y)])\n\n data_path_tags_i = f''\n\n data_path_tags.append(data_path_tags_i)\n\n data_path_tags = \"\".join(data_path_tags)\n\n #\n # Generate data points\n #\n\n if plot_type == \"line\" and show_data_points:\n\n circle_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_point_stroke_color_i = data_point_stroke_color[i]\n data_point_stroke_width_i = data_point_stroke_width[i]\n data_point_fill_color_i = data_point_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n circle_strings_i = f''\n\n else:\n continue\n\n else:\n circle_strings_i = 
f''\n\n circle_strings.append(circle_strings_i)\n\n circle_tags = \"\".join(circle_strings)\n\n #\n # Generate data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n bar_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n data_point_radius_i = data_point_radius[i]\n data_bar_stroke_color_i = data_bar_stroke_color[i]\n data_bar_stroke_width_i = data_bar_stroke_width[i]\n data_bar_fill_color_i = data_bar_fill_color[i]\n\n if data_y_points[i] is None:\n\n if missing_vals == \"marker\":\n\n # Create a symbol that should denote that a missing value is present\n bar_strings_i = f''\n\n else:\n continue\n\n else:\n\n if y_vals[i] < 0:\n\n y_value_i = data_y0_point\n y_height = data_y_points[i] - data_y0_point\n data_bar_stroke_color_i = data_bar_negative_stroke_color\n data_bar_stroke_width_i = data_bar_negative_stroke_width\n data_bar_fill_color_i = data_bar_negative_fill_color\n\n elif y_vals[i] > 0:\n\n y_value_i = data_y_points[i]\n y_height = data_y0_point - data_y_points[i]\n\n elif y_vals[i] == 0:\n\n y_value_i = data_y0_point - 1\n y_height = 2\n data_bar_stroke_color_i = \"#808080\"\n data_bar_stroke_width_i = 4\n data_bar_fill_color_i = \"#808080\"\n\n bar_strings_i = f''\n\n bar_strings.append(bar_strings_i)\n\n bar_tags = \"\".join(bar_strings)\n\n #\n # Generate single horizontal data bars\n #\n\n if plot_type == \"bar\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly several such horizontal bars across different rows that\n # need to be on a common scale\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n 
y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n data_bar_stroke_color = data_bar_negative_stroke_color\n data_bar_stroke_width = data_bar_negative_stroke_width\n data_bar_fill_color = data_bar_negative_fill_color\n\n rect_x = y_width\n rect_width = y0_width - y_width\n\n elif y_vals[0] > 0:\n\n data_bar_stroke_color = data_bar_stroke_color[0]\n data_bar_stroke_width = data_bar_stroke_width[0]\n data_bar_fill_color = data_bar_fill_color[0]\n\n rect_x = y0_width\n rect_width = y_width - y0_width\n\n elif y_vals[0] == 0:\n\n data_bar_stroke_color = \"#808080\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#808080\"\n\n rect_x = y0_width - 2.5\n rect_width = 5\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n bar_tags = f'{g_guide_tags}'\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate single horizontal data lines\n #\n\n # TODO: Make this a line with a single point\n if plot_type == \"line\" and single_horizontal_plot:\n\n # This type of display assumes there is only a single `y` value and there\n # are possibly 
several such horizontal bars across different rows that\n # need to be on a common scale\n\n data_point_radius_i = data_point_radius[0]\n data_point_stroke_color_i = data_point_stroke_color[0]\n data_point_stroke_width_i = data_point_stroke_width[0]\n data_point_fill_color_i = data_point_fill_color[0]\n\n bar_thickness = data_point_radius[0] * 4\n\n if all(val == 0 for val in all_single_y_vals):\n\n # Handle case where all values across rows are `0`\n\n y_proportion = 0.5\n y_proportion_zero = 0.5\n\n else:\n\n # Scale to proportional values\n y_proportions_list = _normalize_to_dict(val=y_vals, all_vals=all_single_y_vals, zero=0)\n\n y_proportion = y_proportions_list[\"val\"][0]\n y_proportion_zero = y_proportions_list[\"zero\"][0]\n\n y0_width = y_proportion_zero * data_x_width\n y_width = y_proportion * data_x_width\n\n if y_vals[0] < 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x1_val\n\n elif y_vals[0] > 0:\n\n x1_val = y0_width\n x2_val = y_width\n\n circle_x_val = x2_val\n\n elif y_vals[0] == 0:\n\n x1_val = y_width\n x2_val = y0_width\n\n circle_x_val = x2_val\n\n # Format number compactly\n y_value = _format_number_compactly(\n val=y_vals[0], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n rect_strings = f''\n\n if y_vals[0] > 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] < 0:\n\n text_strings = f'{y_value}'\n\n elif y_vals[0] == 0:\n\n if all(val == 0 for val in all_single_y_vals):\n\n text_anchor = \"start\"\n x_position_text = y0_width + 10\n\n elif all(val < 0 for val in all_single_y_vals):\n\n text_anchor = \"end\"\n x_position_text = y0_width - 10\n\n else:\n text_anchor = \"start\"\n x_position_text = y0_width + 15\n\n text_strings = f'{y_value}'\n\n g_guide_tags = f'{rect_strings}{text_strings}'\n\n data_path_tags = f'{g_guide_tags}'\n\n circle_tags = f''\n\n zero_line_tags = f''\n\n # Redefine the `viewbox` in terms of the `data_x_width` value; this ensures\n # that the horizontal bars are 
centered about their extreme values\n viewbox = f\"{left_x} {top_y} {data_x_width} {bottom_y}\"\n\n #\n # Generate box plots\n #\n\n if plot_type == \"boxplot\":\n pass\n\n #\n # Generate zero line for vertical bar plots\n #\n\n if plot_type == \"bar\" and single_horizontal_plot is False:\n\n zero_line_tags = f''\n\n #\n # Generate reference line\n #\n\n if show_reference_line:\n\n stroke = reference_line_color\n stroke_width = 1\n stroke_dasharray = \"4 3\"\n transform = \"\"\n stroke_linecap = \"round\"\n vector_effect = \"non-scaling-stroke\"\n\n # Format value in a compact manner\n y_ref_line = _format_number_compactly(\n val=y_ref_line, currency=currency, as_integer=y_vals_integerlike, fn=y_ref_line_fmt_fn\n )\n\n ref_line_tags = f'{y_ref_line}'\n\n #\n # Generate reference area\n #\n\n if show_reference_area:\n\n fill = reference_area_fill_color\n\n p_ul = f\"{data_x_points[0]},{data_y_ref_area_u}\"\n p_ur = f\"{data_x_points[-1]},{data_y_ref_area_u}\"\n p_lr = f\"{data_x_points[-1]},{data_y_ref_area_l}\"\n p_ll = f\"{data_x_points[0]},{data_y_ref_area_l}\"\n\n ref_area_path = f\"M{p_ul},{p_ur},{p_lr},{p_ll}Z\"\n\n ref_area_tags = f''\n\n #\n # Generate y-axis guide\n #\n\n if show_y_axis_guide:\n\n rect_tag = f''\n\n if _is_integerlike(val_list=[y_scale_max]) and _is_integerlike(val_list=[y_scale_min]):\n y_axis_guide_vals_integerlike = True\n else:\n y_axis_guide_vals_integerlike = False\n\n # Format values in a compact manner\n y_value_max_label = _format_number_compactly(\n val=y_scale_max,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n y_value_min_label = _format_number_compactly(\n val=y_scale_min,\n currency=currency,\n as_integer=y_axis_guide_vals_integerlike,\n fn=y_axis_fmt_fn,\n )\n\n text_strings_min = f'{y_value_min_label}'\n\n text_strings_max = f'{y_value_max_label}'\n\n g_y_axis_tags = f'{rect_tag}{text_strings_max}{text_strings_min}'\n\n #\n # Generate vertical data point guidelines\n #\n\n if 
show_vertical_guides:\n\n g_guide_strings = []\n\n for i, _ in enumerate(data_x_points):\n\n rect_strings_i = f''\n\n # Format value in a compact manner\n y_value_i = _format_number_compactly(\n val=y_vals[i], currency=currency, as_integer=y_vals_integerlike, fn=y_val_fmt_fn\n )\n\n x_text = data_x_points[i] + 10\n\n if y_value_i == \"NA\":\n x_text = x_text + 2\n\n text_strings_i = f'{y_value_i}'\n\n g_guide_strings_i = f'{rect_strings_i}{text_strings_i}'\n\n g_guide_strings.append(g_guide_strings_i)\n\n g_guide_tags = \"\".join(g_guide_strings)\n\n #\n # Generate background with repeating line pattern\n #\n\n svg_defs = (\n f\"\"\n f''\n f''\n f\"\"\n f\"\"\n f\"\"\n )\n\n if plot_type == \"line\" and show_data_area:\n\n area_path_tags = []\n\n for i in range(n_segments):\n\n area_x = data_x_points[start_data_y_points[i] : end_data_y_points[i]]\n area_y = data_y_points[start_data_y_points[i] : end_data_y_points[i]]\n\n area_path_string = []\n\n for j in range(0, len(area_x)):\n\n area_path_j = f\"{area_x[j]},{area_y[j]}\"\n area_path_string.append(area_path_j)\n\n area_path_i = f\"M {' '.join(area_path_string)} {area_x[-1]},{bottom_y - safe_y_d + data_point_radius[0]} {area_x[0]},{bottom_y - safe_y_d + data_point_radius[0]} Z\"\n\n area_path_tag_i = f''\n\n area_path_tags.append(area_path_tag_i)\n\n area_path_tags = \" \".join(area_path_tags)\n\n #\n # Generate style tag for vertical guidelines and y-axis\n #\n\n hover_param = \":hover\" if interactive_data_values else \"\"\n\n svg_style = (\n f\"\"\n )\n\n nanoplot_svg = _construct_nanoplot_svg(\n viewbox=viewbox,\n svg_height=svg_height,\n svg_defs=svg_defs,\n svg_style=svg_style,\n show_data_points=show_data_points,\n show_data_line=show_data_line,\n show_data_area=show_data_area,\n show_reference_line=show_reference_line,\n show_reference_area=show_reference_area,\n show_vertical_guides=show_vertical_guides,\n show_y_axis_guide=show_y_axis_guide,\n ref_area_tags=ref_area_tags,\n 
area_path_tags=area_path_tags,\n data_path_tags=data_path_tags,\n zero_line_tags=zero_line_tags,\n bar_tags=bar_tags,\n ref_line_tags=ref_line_tags,\n circle_tags=circle_tags,\n g_y_axis_tags=g_y_axis_tags,\n g_guide_tags=g_guide_tags,\n )\n\n return nanoplot_svg\n```\n###test function signature:\n```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nanoplot_out_bars_with_mixed_ref_area_2", + "reference": "\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n\n out_bars_with_mixed_ref_area_2 = _generate_nanoplot(**CASES[16])\n\n assert _is_nanoplot_output(out_bars_with_mixed_ref_area_2)\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"30.0\"),\n (\"width\", \"40\"),\n (\"height\", \"27.179487179487168\"),\n (\"stroke\", \"#CC3243\"),\n (\"stroke-width\", \"4\"),\n (\"fill\", \"#D75A68\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"rect\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"0\"),\n (\"width\", \"65\"),\n (\"height\", \"130\"),\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"0\"),\n (\"fill\", \"transparent\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"text\",\n attrs=[\n (\"x\", \"0\"),\n (\"y\", \"19.0\"),\n (\"fill\", \"transparent\"),\n (\"stroke\", \"transparent\"),\n (\"font-size\", \"25\"),\n ],\n )\n\n assert _nanoplot_has_tag_attrs(\n out_bars_with_mixed_ref_area_2,\n tag=\"path\",\n attrs=[\n (\"stroke\", \"transparent\"),\n (\"stroke-width\", \"2\"),\n (\"fill\", \"#A6E6F2\"),\n (\"fill-opacity\", \"0.8\"),\n ],\n )", + "completions": [ + "```python\n\ndef test_nanoplot_out_bars_with_mixed_ref_area_2():\n y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n y_ref_area = [1, 10]\n y_ref_line = 5\n x_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n expand_x = \"2020-01-01 00:00:00 UTC\"\n expand_y = \"2020-01-01 00:00:00 
UTC\"\n missing_vals = \"marker\"\n all_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n all_single_y_vals = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n plot_type = \"bar\"\n data_line_type = \"curved\"\n currency = None\n y_val_fmt_fn = None\n y_axis_fmt_fn = None\n y_ref_line_fmt_fn = None\n data_point_radius = 10\n data_point_stroke_color = \"#FFFFFF\"\n data_point_stroke_width = 4\n data_point_fill_color = \"#FF0000\"\n data_line_stroke_color = \"#4682B4\"\n data_line_stroke_width = 8\n data_area_fill_color = \"#FF0000\"\n data_bar_stroke_color = \"#3290CC\"\n data_bar_stroke_width = 4\n data_bar_fill_color = \"#3FB5FF\"\n data_bar_negative_stroke_color = \"#CC3243\"\n data_bar_negative_stroke_width = 4\n data_bar_negative_fill_color = \"#D75A68\"\n reference_line_color = \"#75A8B4\"\n reference_area_fill_color = \"#A6E6F2\"\n vertical_guide_stroke_color = \"#911EB4\"\n vertical_guide_stroke_width = 12\n show_data_points = True\n show_data_line = True\n show_data_area = True\n show_reference_line = True\n show_reference_area = True\n show_vertical_guides = True\n show_y_axis_guide = True\n interactive_data_values = True\n svg_height = \"2em\"\n\n nanoplot_svg = _generate_nanoplot(\n y_vals=y_vals,\n y_ref_line=y_ref_line,\n y_ref_area=y_ref_area,\n x_vals=x_vals,\n expand_x=expand_x,\n expand_y=expand_y,\n missing_vals=missing_vals,\n all_y_vals=all_y_vals,\n all_single_y_vals=all_single_y_vals,\n plot_type=plot_type,\n data_line_type=data_line_type,\n currency=currency,\n y_val_fmt_fn=y_val_fmt_fn,\n y_axis_fmt_fn=y_axis_fmt_fn,\n y_ref_line_fmt_fn=y_ref_line_fmt_fn,\n data_point_radius=data_point_radius,\n data_point_stroke_color=data_point_stroke_color,\n data_point_stroke_width=data_point_stroke_width,\n data_point_fill_color=data_point_fill_color,\n data_line_stroke_color=data_line_stroke_color,\n data_line_" + ], + "line": 749, + "token": 7038, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 77, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def mean(self, lower=0.0, upper=np.inf):\n loc = []\n for idx in self.index:\n qpd = self.qpd.loc[idx, :].values[0]\n l, _ = quad(exp_func, args=(qpd), a=lower, b=upper)\n loc.append(l)\n loc_arr = np.array(loc)\n return pd.DataFrame(loc_arr, index=self.index, columns=self.columns)\n```\n###test function signature:\n```python\ndef test_qpd_b_simple_use():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_qpd_b_simple_use", + "reference": "def test_qpd_b_simple_use():\n from skpro.distributions.qpd import QPD_B\n\n qpd = QPD_B(\n alpha=0.2,\n qv_low=[1, 2],\n qv_median=[3, 4],\n qv_high=[5, 6],\n lower=0,\n upper=10,\n )\n\n qpd.mean()", + "completions": [ + "```python\ndef test_qpd_b_simple_use():\n import numpy as np\n import pandas as pd\n from scipy.integrate import quad\n from scipy.special import expit\n\n def exp_func(x, qpd):\n return expit(qpd * x)\n\n qpd = pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z'])\n qpd_b = qpd.mean(lower=0.0, upper=np.inf)\n assert qpd_b.equals(pd.DataFrame([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],\n index=['a', 'b', 'c'], columns=['x', 'y', 'z']))\n```\n" + ], + "line": 18, + "token": 146, + "line_diff": 0, + "token_diff": 0 + }, + { + "id": 78, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef create(\n supply_settings: Callable[[], config.ResilienceSettings] = config.ResilienceSettings, include_post: bool = False\n) -> adapters.HTTPAdapter:\n settings = supply_settings()\n allowed_methods = set(util.Retry.DEFAULT_ALLOWED_METHODS)\n if include_post:\n allowed_methods.add(\"POST\")\n\n retries = util.Retry(\n total=settings.total,\n backoff_factor=settings.backoff_factor,\n status_forcelist=settings.statuses,\n allowed_methods=allowed_methods,\n )\n return adapters.HTTPAdapter(max_retries=retries)\n```\n###test function signature:\n```python\n\ndef test_should_create_adapter(include_post):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_should_create_adapter", + "reference": "\ndef test_should_create_adapter(include_post):\n total_retries = 5\n backoff_factor = 0.4\n statuses = (random.randint(400, 600),)\n settings = config.ResilienceSettings(total=total_retries, backoff_factor=backoff_factor, statuses=statuses)\n adapter = http_adapter.create(supply_settings=lambda: settings, include_post=include_post)\n assert adapter.max_retries.total == total_retries\n assert adapter.max_retries.backoff_factor == backoff_factor\n assert adapter.max_retries.status_forcelist == statuses\n allowed_methods = adapter.max_retries.allowed_methods or []\n if include_post:\n assert \"POST\" in allowed_methods\n else:\n assert \"POST\" not in allowed_methods", + "completions": [ + "```python\n\ndef test_should_create_adapter(include_post):\n assert create(include_post=include_post) == adapters.HTTPAdapter(max_retries=util.Retry(\n total=config.ResilienceSettings().total,\n backoff_factor=config.ResilienceSettings().backoff_factor,\n status_forcelist=config.ResilienceSettings().statuses,\n 
allowed_methods=set(util.Retry.DEFAULT_ALLOWED_METHODS) if not include_post else set(util.Retry.DEFAULT_ALLOWED_METHODS).union({\"POST\"})\n ))\n```\n" + ], + "line": 24, + "token": 143, + "line_diff": 1, + "token_diff": 0 + }, + { + "id": 79, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef extract_best_config(trial_path: str, output_path: Optional[str] = None) -> Dict:\n summary_path = os.path.join(trial_path, 'summary.csv')\n if not os.path.exists(summary_path):\n raise ValueError(f\"summary.csv does not exist in {trial_path}.\")\n trial_summary_df = load_summary_file(summary_path, dict_columns=['best_module_params'])\n config_yaml_path = os.path.join(trial_path, 'config.yaml')\n with open(config_yaml_path, 'r') as f:\n config_dict = yaml.safe_load(f)\n yaml_dict = summary_df_to_yaml(trial_summary_df, config_dict)\n if output_path is not None:\n with open(output_path, 'w') as f:\n yaml.dump(yaml_dict, f)\n return yaml_dict\n```\n###test function signature:\n```python\n\ndef test_runner(evaluator):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_runner", + "reference": "\ndef test_runner(evaluator):\n os.environ['BM25'] = 'bm25'\n evaluator.start_trial(os.path.join(resource_dir, 'simple.yaml'))\n project_dir = evaluator.project_dir\n\n def runner_test(runner: Runner):\n answer = runner.run('What is the best movie in Korea? 
Have Korea movie ever won Oscar?',\n 'retrieved_contents')\n assert len(answer) == 1\n assert isinstance(answer, list)\n assert isinstance(answer[0], str)\n\n runner = Runner.from_trial_folder(os.path.join(project_dir, '0'))\n runner_test(runner)\n runner_test(runner)\n\n with tempfile.NamedTemporaryFile(suffix='yaml', mode='w+t') as yaml_path:\n extract_best_config(os.path.join(project_dir, '0'), yaml_path.name)\n runner = Runner.from_yaml(yaml_path.name, project_dir=project_dir)\n runner_test(runner)", + "completions": [ + "```python\ndef test_runner(evaluator):\n def test_extract_best_config():\n trial_path = 'path/to/trial'\n output_path = 'path/to/output'\n yaml_dict = extract_best_config(trial_path, output_path)\n assert yaml_dict is not None\n assert isinstance(yaml_dict, dict)\n assert 'best_module_params' in yaml_dict\n assert isinstance(yaml_dict['best_module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params'], dict)\n assert 'module_name' in yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']\n assert isinstance(yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['module_name'], str)\n assert 'module_params' in 
yaml_dict['best_module_params']['module_params']['module_params']['module_params']['module_params']['module_params']['" + ], + "line": 23, + "token": 205, + "line_diff": 0, + "token_diff": 1 + }, + { + "id": 80, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\n def init_weights(self):\n logger = MMLogger.get_current_instance()\n if self.init_cfg is None:\n logger.warn(f'No pre-trained weights for '\n f'{self.__class__.__name__}, '\n f'training start from scratch')\n if self.use_abs_pos_embed:\n trunc_normal_(self.absolute_pos_embed, std=0.02)\n for m in self.modules():\n if isinstance(m, nn.Linear):\n trunc_normal_init(m, std=.02, bias=0.)\n elif isinstance(m, nn.LayerNorm):\n constant_init(m, 1.0)\n else:\n assert 'checkpoint' in self.init_cfg, f'Only support ' \\\n f'specify `Pretrained` in ' \\\n f'`init_cfg` in ' \\\n f'{self.__class__.__name__} '\n ckpt = CheckpointLoader.load_checkpoint(\n self.init_cfg.checkpoint, logger=logger, map_location='cpu')\n if 'state_dict' in ckpt:\n _state_dict = ckpt['state_dict']\n elif 'model' in ckpt:\n _state_dict = ckpt['model']\n else:\n _state_dict = ckpt\n if self.convert_weights:\n # supported loading weight from original repo,\n _state_dict = swin_converter(_state_dict)\n\n state_dict = OrderedDict()\n for k, v in _state_dict.items():\n if k.startswith('backbone.'):\n state_dict[k[9:]] = v\n\n # strip prefix of state_dict\n if list(state_dict.keys())[0].startswith('module.'):\n state_dict = {k[7:]: v for k, v in state_dict.items()}\n\n # reshape absolute position embedding\n if state_dict.get('absolute_pos_embed') is not None:\n absolute_pos_embed = state_dict['absolute_pos_embed']\n N1, L, C1 = absolute_pos_embed.size()\n N2, C2, H, W = self.absolute_pos_embed.size()\n if N1 != N2 or C1 != C2 or L != H * W:\n 
logger.warning('Error in loading absolute_pos_embed, pass')\n else:\n state_dict['absolute_pos_embed'] = absolute_pos_embed.view(\n N2, H, W, C2).permute(0, 3, 1, 2).contiguous()\n\n # interpolate position bias table if needed\n relative_position_bias_table_keys = [\n k for k in state_dict.keys()\n if 'relative_position_bias_table' in k\n ]\n for table_key in relative_position_bias_table_keys:\n table_pretrained = state_dict[table_key]\n table_current = self.state_dict()[table_key]\n L1, nH1 = table_pretrained.size()\n L2, nH2 = table_current.size()\n if nH1 != nH2:\n logger.warning(f'Error in loading {table_key}, pass')\n elif L1 != L2:\n S1 = int(L1**0.5)\n S2 = int(L2**0.5)\n table_pretrained_resized = F.interpolate(\n table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1),\n size=(S2, S2),\n mode='bicubic')\n state_dict[table_key] = table_pretrained_resized.view(\n nH2, L2).permute(1, 0).contiguous()\n\n # load state_dict\n self.load_state_dict(state_dict, False)\n```\n###test function signature:\n```python\ndef test_swin_transformer():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_swin_transformer", + "reference": "def test_swin_transformer():\n\n with pytest.raises(TypeError):\n # Pretrained arg must be str or None.\n SwinTransformer(pretrained=123)\n\n with pytest.raises(AssertionError):\n # Because swin uses non-overlapping patch embed, so the stride of patch\n # embed must be equal to patch size.\n SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)\n\n # test pretrained image size\n with pytest.raises(AssertionError):\n SwinTransformer(pretrain_img_size=(224, 224, 224))\n\n # Test absolute position embedding\n temp = torch.randn((1, 3, 224, 224))\n model = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)\n model.init_weights()\n model(temp)\n\n # Test patch norm\n model = SwinTransformer(patch_norm=False)\n model(temp)\n\n # Test normal inference\n temp = torch.randn((1, 3, 32, 32))\n model = 
SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 31, 31))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 8, 8)\n assert outs[1].shape == (1, 192, 4, 4)\n assert outs[2].shape == (1, 384, 2, 2)\n assert outs[3].shape == (1, 768, 1, 1)\n\n # Test abnormal inference size\n temp = torch.randn((1, 3, 112, 137))\n model = SwinTransformer()\n outs = model(temp)\n assert outs[0].shape == (1, 96, 28, 35)\n assert outs[1].shape == (1, 192, 14, 18)\n assert outs[2].shape == (1, 384, 7, 9)\n assert outs[3].shape == (1, 768, 4, 5)\n\n model = SwinTransformer(frozen_stages=4)\n model.train()\n for p in model.parameters():\n assert not p.requires_grad", + "completions": [ + "```python\ndef test_swin_transformer():\n from mmcv.cnn import (build_pos_embed, build_pos_embed_sine,\n constant_init, trunc_normal_init)\n from mmcv.runner import CheckpointLoader\n from mmdet.models.builder import build_backbone\n from mmdet.models.utils import swin_converter\n from mmdet.utils import get_root_logger\n from mmdet.utils.logger import MMLogger\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import PatchEmbed\n from mmdet.models.backbones.swin import PatchMerging\n from mmdet.models.backbones.swin import BasicLayer\n from mmdet.models.backbones.swin import BasicBlock\n from mmdet.models.backbones.swin import DepthwiseSeparableConv\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from 
mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransformer\n from mmdet.models.backbones.swin import SwinTransform" + ], + "line": 79, + "token": 586, + "line_diff": 6, + "token_diff": 5 + }, + { + "id": 81, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef columns_equal(\n col_1: \"pl.Series\",\n col_2: \"pl.Series\",\n rel_tol: float = 0,\n abs_tol: float = 0,\n ignore_spaces: bool = False,\n ignore_case: bool = False,\n) -> \"pl.Series\":\n compare: pl.Series\n try:\n compare = pl.Series(\n np.isclose(col_1, col_2, rtol=rel_tol, atol=abs_tol, equal_nan=True)\n )\n except TypeError:\n try:\n if col_1.dtype in DATE_TYPE or col_2 in DATE_TYPE:\n raise TypeError(\"Found date, moving to alternative logic\")\n\n compare = pl.Series(\n np.isclose(\n col_1.cast(pl.Float64, strict=True),\n col_2.cast(pl.Float64, strict=True),\n rtol=rel_tol,\n atol=abs_tol,\n equal_nan=True,\n )\n )\n except (ValueError, TypeError, InvalidOperationError, ComputeError):\n try:\n if ignore_spaces:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.strip_chars()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.strip_chars()\n\n if ignore_case:\n if str(col_1.dtype) in STRING_TYPE:\n col_1 = col_1.str.to_uppercase()\n if str(col_2.dtype) in STRING_TYPE:\n col_2 = col_2.str.to_uppercase()\n\n if (\n str(col_1.dtype) in STRING_TYPE and str(col_2.dtype) in DATE_TYPE\n ) or (\n str(col_1.dtype) in DATE_TYPE and str(col_2.dtype) in STRING_TYPE\n ):\n compare = compare_string_and_date_columns(col_1, col_2)\n else:\n compare = pl.Series(\n (col_1.eq_missing(col_2)) | (col_1.is_null() & col_2.is_null())\n )\n except:\n # Blanket exception should just return all False\n compare = pl.Series(False * col_1.shape[0])\n return compare\n```\n###test function signature:\n```python\n\ndef test_date_columns_equal():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_date_columns_equal", + "reference": "\ndef test_date_columns_equal():\n data = 
\"\"\"a|b|expected\n2017-01-01|2017-01-01|True\n2017-01-02|2017-01-02|True\n2017-10-01|2017-10-10|False\n2017-01-01||False\n|2017-01-01|False\n||True\"\"\"\n df = pl.read_csv(\n io.StringIO(data),\n separator=\"|\",\n null_values=[\"NULL\"],\n missing_utf8_is_empty_string=True,\n )\n # First compare just the strings\n actual_out = columns_equal(df[\"a\"], df[\"b\"], rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)\n\n # Then compare converted to datetime objects\n col_a = df[\"a\"].str.to_datetime(strict=False)\n col_b = df[\"b\"].str.to_datetime(strict=False)\n actual_out = columns_equal(col_a, col_b, rel_tol=0.2)\n expect_out = df[\"expected\"]\n assert_series_equal(expect_out, actual_out, check_names=False)", + "completions": [ + "```python\ndef test_date_columns_equal():\n col_1 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n col_2 = pl.Series([\"2022-01-01\", \"2022-01-02\", \"2022-01-03\"])\n compare = columns_equal(col_1, col_2, ignore_spaces=True, ignore_case=True)\n assert compare.sum() == 3\n```\n" + ], + "line": 62, + "token": 360, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 82, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def upsert_child_firing(\n db: PrefectDBInterface,\n session: AsyncSession,\n firing: Firing,\n):\n automation_id = firing.trigger.automation.id\n parent_trigger_id = firing.trigger.parent.id\n child_trigger_id = firing.trigger.id\n\n upsert = (\n postgresql.insert(db.CompositeTriggerChildFiring)\n .values(\n automation_id=automation_id,\n parent_trigger_id=parent_trigger_id,\n child_trigger_id=child_trigger_id,\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n )\n .on_conflict_do_update(\n index_elements=[\n db.CompositeTriggerChildFiring.automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id,\n ],\n set_=dict(\n child_firing_id=firing.id,\n child_fired_at=firing.triggered,\n child_firing=firing.dict(),\n updated=pendulum.now(\"UTC\"),\n ),\n )\n )\n\n await session.execute(upsert)\n\n result = await session.execute(\n sa.select(db.CompositeTriggerChildFiring).filter(\n db.CompositeTriggerChildFiring.automation_id == automation_id,\n db.CompositeTriggerChildFiring.parent_trigger_id == parent_trigger_id,\n db.CompositeTriggerChildFiring.child_trigger_id == child_trigger_id,\n )\n )\n\n return result.scalars().one()\n```\n###test function signature:\n```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_insert_compound_trigger_child_firing", + "reference": "\nasync def test_insert_compound_trigger_child_firing(\n automations_session: 
AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n firing = Firing(\n trigger=saw_a_red_dragon,\n trigger_states={TriggerState.Triggered},\n triggered=frozen_time.add(seconds=1),\n triggering_labels={},\n triggering_event=baby_red_dragon_passed_by,\n )\n res = await upsert_child_firing(\n session=automations_session,\n firing=firing,\n )\n\n assert res and res.id\n assert res.automation_id == raise_the_alarm.id\n assert res.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert res.child_trigger_id == saw_a_red_dragon.id\n assert res.child_firing_id == firing.id\n assert res.child_fired_at == firing.triggered\n\n child_firings = await get_child_firings(\n session=automations_session, trigger=saw_a_red_and_green_dragon\n )\n assert len(child_firings) == 1\n assert child_firings[0].id == res.id\n assert child_firings[0].child_firing_id == firing.id", + "completions": [ + "```python\n\nasync def test_insert_compound_trigger_child_firing(\n automations_session: AsyncSession,\n raise_the_alarm: Automation,\n saw_a_red_and_green_dragon: CompoundTrigger,\n saw_a_red_dragon: EventTrigger,\n baby_red_dragon_passed_by: ReceivedEvent,\n frozen_time: DateTime,\n):\n \"\"\"\n Test that a compound trigger child firing is inserted correctly\n \"\"\"\n # Arrange\n firing = Firing(\n id=uuid4(),\n triggered=frozen_time,\n trigger=saw_a_red_and_green_dragon,\n event=baby_red_dragon_passed_by,\n )\n\n # Act\n result = await upsert_child_firing(\n db=automations_session,\n session=automations_session,\n firing=firing,\n )\n\n # Assert\n assert result.automation_id == raise_the_alarm.id\n assert result.parent_trigger_id == saw_a_red_and_green_dragon.id\n assert result.child_trigger_id == saw_a_red_dragon.id\n assert result.child_firing_id == firing.id\n assert result.child_fired_at == firing.triggered\n assert result.child_firing == 
firing.dict()\n assert result.updated == frozen_time\n\n```\n" + ], + "line": 58, + "token": 225, + "line_diff": 5, + "token_diff": 1 + }, + { + "id": 83, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef write_output(output_file, output_type, results):\n\n # If json or txt, write contents directly to file.\n if output_type == \"json\" or output_type == \"txt\":\n with open(output_file, \"w\") as f:\n json.dump(results, f, indent=4)\n return\n # If csv or xlsx, convert results to pandas dataframes.\n urls_df = get_urls_df(results)\n codes_df = get_codes_df(results)\n\n # If csv, write dataframes to separate csv files for urls, codes.\n if output_type == \"csv\":\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n urls_df.to_csv(urls_output_file, index=False)\n codes_df.to_csv(codes_output_file, index=False)\n\n # If xlsx, write dataframes to separate sheets for urls, codes.\n if output_type == \"xlsx\":\n writer = pd.ExcelWriter(output_file, engine=\"xlsxwriter\")\n urls_df.to_excel(writer, sheet_name=\"URLs\", index=False)\n codes_df.to_excel(writer, sheet_name=\"Codes\", index=False)\n writer.close()\n```\n###test function signature:\n```python\n def test_write_output_csv(self, mock_urls, mock_codes):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_output_csv", + "reference": " def test_write_output_csv(self, mock_urls, mock_codes):\n\n test_file = \"./test_output/test_file.csv\"\n test_file_urls = \"./test_output/test_file_urls.csv\"\n test_file_codes = \"./test_output/test_file_codes.csv\"\n test_results = {\"test\": \"test\"}\n mock_urls.return_value = pd.DataFrame([test_results])\n mock_codes.return_value = 
pd.DataFrame([test_results])\n\n with open(test_file_urls, \"w\") as f:\n pass\n\n with open(test_file_codes, \"w\") as f:\n pass\n\n write_output(test_file, \"csv\", test_results)\n\n test_data_urls = pd.read_csv(test_file_urls).to_dict(orient=\"records\")[0]\n test_data_codes = pd.read_csv(test_file_codes).to_dict(orient=\"records\")[0]\n\n os.remove(test_file_urls)\n os.remove(test_file_codes)\n\n self.assertEqual(test_data_urls, test_results)\n self.assertEqual(test_data_codes, test_results)", + "completions": [ + "```python\ndef test_write_output_csv(self, mock_urls, mock_codes):\n \"\"\"\n Test write_output function with csv output type.\n \"\"\"\n output_file = \"output.csv\"\n output_type = \"csv\"\n results = {\"urls\": mock_urls, \"codes\": mock_codes}\n write_output(output_file, output_type, results)\n\n # Check that the urls and codes dataframes are written to separate csv files.\n urls_output_file = output_file.replace(\".csv\", \"_urls.csv\")\n codes_output_file = output_file.replace(\".csv\", \"_codes.csv\")\n self.assertTrue(os.path.isfile(urls_output_file))\n self.assertTrue(os.path.isfile(codes_output_file))\n\n # Check that the dataframes are correct.\n urls_df = pd.read_csv(urls_output_file)\n codes_df = pd.read_csv(codes_output_file)\n self.assertEqual(urls_df.shape, (2, 2))\n self.assertEqual(codes_df.shape, (2, 2))\n self.assertEqual(urls_df.iloc[0, 0], \"url1\")\n self.assertEqual(codes_df.iloc[0, 0], \"code1\")\n self.assertEqual(urls_df.iloc[1, 0], \"url2\")\n self.assertEqual(codes_df.iloc[1, 0], \"code2\")\n```\n" + ], + "line": 31, + "token": 279, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 84, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\nasync def terminate_idle_instance() -> None:\n async with get_session_ctx() as session:\n async with PROCESSING_POOL_LOCK:\n res = await session.execute(\n select(InstanceModel)\n .where(\n InstanceModel.termination_policy == TerminationPolicy.DESTROY_AFTER_IDLE,\n InstanceModel.deleted == False,\n InstanceModel.job == None, # noqa: E711\n InstanceModel.status == InstanceStatus.IDLE,\n )\n .options(joinedload(InstanceModel.project))\n )\n instances = res.scalars().all()\n\n for instance in instances:\n last_time = instance.created_at.replace(tzinfo=datetime.timezone.utc)\n if instance.last_job_processed_at is not None:\n last_time = instance.last_job_processed_at.replace(\n tzinfo=datetime.timezone.utc\n )\n\n idle_seconds = instance.termination_idle_time\n delta = datetime.timedelta(seconds=idle_seconds)\n\n current_time = get_current_datetime().replace(tzinfo=datetime.timezone.utc)\n\n if last_time + delta < current_time:\n jpd = JobProvisioningData.__response__.parse_raw(\n instance.job_provisioning_data\n )\n await terminate_job_provisioning_data_instance(\n project=instance.project, job_provisioning_data=jpd\n )\n instance.deleted = True\n instance.deleted_at = get_current_datetime()\n instance.finished_at = get_current_datetime()\n instance.status = InstanceStatus.TERMINATED\n instance.termination_reason = \"Idle timeout\"\n\n idle_time = current_time - last_time\n logger.info(\n \"Instance %s terminated by termination policy: idle time %ss\",\n instance.name,\n str(idle_time.seconds),\n extra={\n \"instance_name\": instance.name,\n \"instance_status\": InstanceStatus.TERMINATED.value,\n },\n )\n\n await session.commit()\n```\n###test function signature:\n```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: 
AsyncSession):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_terminate_by_idle_timeout", + "reference": "\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n project = await create_project(session=session)\n pool = await create_pool(session, project)\n\n instance = await create_instance(session, project, pool, status=InstanceStatus.IDLE)\n instance.termination_idle_time = 300\n instance.termination_policy = TerminationPolicy.DESTROY_AFTER_IDLE\n instance.last_job_processed_at = get_current_datetime() + dt.timedelta(minutes=-19)\n await session.commit()\n\n with patch(\n \"dstack._internal.server.background.tasks.process_instances.terminate_job_provisioning_data_instance\"\n ):\n await terminate_idle_instance()\n\n await session.refresh(instance)\n\n assert instance is not None\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"", + "completions": [ + "```python\n\n async def test_terminate_by_idle_timeout(self, test_db, session: AsyncSession):\n # given\n project = ProjectModel(name=\"test_project\")\n await session.add(project)\n await session.commit()\n\n instance = InstanceModel(\n name=\"test_instance\",\n project=project,\n termination_policy=TerminationPolicy.DESTROY_AFTER_IDLE,\n termination_idle_time=10,\n status=InstanceStatus.IDLE,\n created_at=datetime.datetime.now(datetime.timezone.utc),\n )\n await session.add(instance)\n await session.commit()\n\n # when\n await terminate_idle_instance()\n\n # then\n instance = await session.get(InstanceModel, instance.id)\n assert instance.status == InstanceStatus.TERMINATED\n assert instance.termination_reason == \"Idle timeout\"\n```\n" + ], + "line": 56, + "token": 289, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 85, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef nargsort(\n items,\n kind: str = \"quicksort\",\n ascending: bool = True,\n na_position: str = \"last\",\n key: Callable | None = None,\n mask: npt.NDArray[np.bool_] | None = None,\n) -> npt.NDArray[np.intp]:\n\n if key is not None:\n items = ensure_key_mapped(items, key)\n return nargsort(\n items,\n kind=kind,\n ascending=ascending,\n na_position=na_position,\n key=None,\n mask=mask,\n )\n\n if isinstance(items, ABCRangeIndex):\n return items.argsort(ascending=ascending) # TODO: test coverage with key?\n elif not isinstance(items, ABCMultiIndex):\n items = extract_array(items)\n if mask is None:\n mask = np.asarray(isna(items)) # TODO: does this exclude MultiIndex too?\n\n if is_extension_array_dtype(items):\n return items.argsort(ascending=ascending, kind=kind, na_position=na_position)\n else:\n items = np.asanyarray(items)\n\n idx = np.arange(len(items))\n non_nans = items[~mask]\n non_nan_idx = idx[~mask]\n\n nan_idx = np.nonzero(mask)[0]\n if not ascending:\n non_nans = non_nans[::-1]\n non_nan_idx = non_nan_idx[::-1]\n indexer = non_nan_idx[non_nans.argsort(kind=kind)]\n if not ascending:\n indexer = indexer[::-1]\n # Finally, place the NaNs at the end or the beginning according to\n # na_position\n if na_position == \"last\":\n indexer = np.concatenate([indexer, nan_idx])\n elif na_position == \"first\":\n indexer = np.concatenate([nan_idx, indexer])\n else:\n raise ValueError(f\"invalid na_position: {na_position}\")\n return ensure_platform_int(indexer)\n```\n###test function signature:\n```python\n\n def test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_nargsort", + "reference": "\n def 
test_nargsort(self, ascending, na_position, exp, box):\n # list places NaNs last, np.array(..., dtype=\"O\") may not place NaNs first\n items = box([np.nan] * 5 + list(range(100)) + [np.nan] * 5)\n\n # mergesort is the most difficult to get right because we want it to be\n # stable.\n\n # According to numpy/core/tests/test_multiarray, \"\"\"The number of\n # sorted items must be greater than ~50 to check the actual algorithm\n # because quick and merge sort fall over to insertion sort for small\n # arrays.\"\"\"\n\n result = nargsort(\n items, kind=\"mergesort\", ascending=ascending, na_position=na_position\n )\n tm.assert_numpy_array_equal(result, np.array(exp), check_dtype=False)", + "completions": [ + "```python\n def test_nargsort(self, ascending, na_position, exp, box):\n items = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216," + ], + "line": 58, + "token": 412, + "line_diff": 5, + "token_diff": 4 + }, + { + "id": 86, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef generate_rotated_rb_circuits(\n n_qubits: int,\n num_cliffords: int,\n theta: Optional[float] = None,\n trials: int = 1,\n return_type: Optional[str] = None,\n seed: Optional[int] = None,\n) -> List[QPROGRAM]:\n\n circuits = cast(\n List[cirq.Circuit],\n generate_rb_circuits(n_qubits, num_cliffords, trials, seed=seed),\n )\n\n if theta is None:\n theta = 2 * np.pi * np.random.rand()\n\n for circ in circuits:\n qubits = list(circ.all_qubits())\n circ.insert(len(circ) // 2, cirq.Rz(rads=theta).on(qubits[0]))\n\n return_type = \"cirq\" if not return_type else return_type\n return [convert_from_mitiq(circuit, return_type) for circuit in circuits]\n```\n###test function signature:\n```python\n\ndef test_rotated_rb_circuits(n_qubits, theta):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_rotated_rb_circuits", + "reference": "\ndef test_rotated_rb_circuits(n_qubits, theta):\n depth = 10\n for trials in [2, 3]:\n circuits = generate_rotated_rb_circuits(\n n_qubits=n_qubits, num_cliffords=depth, theta=theta, trials=trials\n )\n for circ in circuits:\n zero_prob = (\n cirq.DensityMatrixSimulator()\n .simulate(circ)\n .final_density_matrix[0, 0]\n .real\n )\n assert -1.0001 <= zero_prob <= 1.0001", + "completions": [ + "```python\ndef test_rotated_rb_circuits(n_qubits, theta):\n from cirq import Circuit, LineQubit\n from cirq.testing import assert_has_diagram\n from mitiq.generate import generate_rotated_rb_circuits\n\n circuits = generate_rotated_rb_circuits(n_qubits, 10, theta)\n for circuit in circuits:\n assert_has_diagram(\n circuit,\n f\"\"\"\n0: ───Rz({theta})───\n1: ───Rz({theta})───\n2: ───Rz({theta})───\n3: ───Rz({theta})───\n4: ───Rz({theta})───\n5: ───Rz({theta})───\n6: ───Rz({theta})───\n7: ───Rz({theta})───\n8: 
───Rz({theta})───\n9: ───Rz({theta})───\n10: ───Rz({theta})───\n11: ───Rz({theta})───\n12: ───Rz({theta})───\n13: ───Rz({theta})───\n14: ───Rz({theta})───\n15: ───Rz({theta})───\n16: ───Rz({theta})───\n17: ───Rz({theta})───\n18: ───Rz({theta})───\n19: ───Rz({theta})───\n20: ───Rz({theta})───\n21: ───Rz({theta})───\n22: ───Rz({theta})───\n23: ───Rz({theta})───\n24: ───Rz({theta})───\n25: ───Rz({theta})───\n26: ───Rz({theta})───\n27: ───Rz({theta})───\n28: ───Rz({theta})───\n29: ───Rz({theta})───\n30: ───Rz({theta})───\n31: ───Rz({theta})───\n32: ───Rz({theta})───\n33: ───Rz({theta})───\n34: ───Rz({theta})───\n35: ───Rz({theta})───\n36: ───Rz({theta})───\n37: ───Rz({theta})───\n38: ───Rz({theta})───\n39: ───Rz({theta})───\n40: ───Rz({theta})───\n41: ───Rz({theta})───\n42: ───Rz({theta})───\n43: ───Rz({theta})───\n44: ───Rz({theta})───\n45: ───Rz({theta})───\n46: ───Rz({theta})───\n47: ───Rz({theta})───\n48: ───Rz({theta})───\n49: ───Rz({theta})───\n50: ───Rz({theta})───\n51: ───Rz({theta})───\n52: ───Rz({theta})───\n53: ───Rz({theta})───\n54: ───Rz({theta})───\n55: ───Rz({theta})───\n56: ───Rz({theta})───\n57: ───Rz({theta})───\n58: ───Rz({theta})───\n59: ───Rz({theta})───\n60: ───Rz({theta})───\n61: ───Rz({theta})───\n62: ───Rz({theta})───\n63: ───" + ], + "line": 29, + "token": 226, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 87, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef eye(N, M=None, k=None, dtype=None):\n dtype = to_torch_dtype(dtype or config.floatx())\n M = N if M is None else M\n k = 0 if k is None else k\n if k == 0:\n # TODO: torch.eye doesn't support bfloat16 with cpu\n if get_device() == \"cpu\" and dtype == torch.bfloat16:\n return cast(\n torch.eye(\n N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n ),\n dtype,\n )\n return torch.eye(N, M, dtype=dtype, device=get_device())\n diag_length = builtins.max(N, M)\n diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n return torch.diag(diag, diagonal=k)[:N, :M]\n```\n###test function signature:\n```python\n\n def test_eye(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_eye", + "reference": "\n def test_eye(self):\n self.assertAllClose(knp.eye(3), np.eye(3))\n self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n self.assertAllClose(knp.Eye()(3), np.eye(3))\n self.assertAllClose(knp.Eye()(3, 4), np.eye(3, 4))\n self.assertAllClose(knp.Eye(k=1)(3, 4), np.eye(3, 4, k=1))\n\n # Test k >= N\n self.assertAllClose(knp.Eye(k=3)(3), np.eye(3, k=3))\n\n # Test k > 0 and N >= M\n self.assertAllClose(knp.Eye(k=1)(3), np.eye(3, k=1))\n\n # Test k > 0 and N < M and N + k > M\n self.assertAllClose(knp.Eye(k=2)(3, 4), np.eye(3, 4, k=2))\n\n # Test k < 0 and M >= N\n self.assertAllClose(knp.Eye(k=-1)(3), np.eye(3, k=-1))\n\n # Test k < 0 and M < N and M - k > N\n self.assertAllClose(knp.Eye(k=-2)(4, 3), np.eye(4, 3, k=-2))", + "completions": [ + "```python\n\n def test_eye(self):\n self.assertEqual(eye(3), torch.eye(3))\n self.assertEqual(eye(3, 4), torch.eye(3, 4))\n self.assertEqual(eye(3, 4, 1), torch.eye(3, 4, 1))\n self.assertEqual(eye(3, 4, -1), torch.eye(3, 
4, -1))\n self.assertEqual(eye(3, 4, 2), torch.eye(3, 4, 2))\n self.assertEqual(eye(3, 4, -2), torch.eye(3, 4, -2))\n self.assertEqual(eye(3, 4, 3), torch.eye(3, 4, 3))\n self.assertEqual(eye(3, 4, -3), torch.eye(3, 4, -3))\n self.assertEqual(eye(3, 4, 4), torch.eye(3, 4, 4))\n self.assertEqual(eye(3, 4, -4), torch.eye(3, 4, -4))\n self.assertEqual(eye(3, 4, 5), torch.eye(3, 4, 5))\n self.assertEqual(eye(3, 4, -5), torch.eye(3, 4, -5))\n self.assertEqual(eye(3, 4, 6), torch.eye(3, 4, 6))\n self.assertEqual(eye(3, 4, -6), torch.eye(3, 4, -6))\n self.assertEqual(eye(3, 4, 7), torch.eye(3, 4, 7))\n self.assertEqual(eye(3, 4, -7), torch.eye(3, 4, -7))\n self.assertEqual(eye(3, 4, 8), torch.eye(3, 4, 8))\n self.assertEqual(eye(3, 4, -8), torch.eye(3, 4, -8))\n self.assertEqual(eye(3, 4, 9), torch.eye(3, 4, 9))\n self.assertEqual(eye(3, 4, -9), torch.eye(3, 4, -9))\n self.assertEqual(eye(3, 4, 10), torch.eye(3, 4, 10))\n self.assertEqual(eye(3, 4, -10), torch.eye(3, 4, -10))\n self.assertEqual(eye(3, 4, 11), torch.eye(3, 4, 11))\n self.assertEqual(eye(3, 4, -11), torch.eye(3, 4, -11))\n self.assertEqual(eye(3, 4, 12), torch.eye(3, 4, 12))\n self.assertEqual(eye(3, 4, -12), torch.eye(3, 4, -12))\n self.assertEqual(eye(3, 4, 13), torch.eye(3, 4, 13))\n self.assertEqual(eye(3, 4, -13), torch.eye(3, 4, -13))\n self.assertEqual(eye(3, 4, 14), torch.eye(3, 4, 14))\n self.assertEqual(eye(3, 4, -14), torch.eye(3, 4, -14))\n self.assertEqual(eye(3, 4, 15), torch.eye(" + ], + "line": 27, + "token": 219, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 88, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef ragged_qk_dotprod(\n query: RaggedActivations, key: RaggedActivations, lut: RaggedQkPidLookupTable\n) -> torch.Tensor:\n device = query.device\n\n assert query.raw_tensor.is_contiguous()\n assert key.raw_tensor.is_contiguous()\n\n # check constraints\n total_ctx_q_across_all_seqs, d_head = query.raw_tensor.shape\n total_ctx_k_across_all_seqs, d_head_k = key.raw_tensor.shape\n assert d_head == d_head_k, f\"{query.raw_tensor.shape=} {key.raw_tensor.shape=}\"\n\n # allocates output\n # max_n_ctx_q_across_seqs = query.max_n_ctx_per_seq\n\n assert query.n_seqs == key.n_seqs\n # TODO: flag use zeros for garbage\n scores_out = torch.ones(\n (query.n_seqs, query.max_n_ctx_per_seq, key.max_n_ctx_per_seq),\n device=device,\n dtype=query.dtype,\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_out_seq_idx,\n # Integers\n max_n_ctx_q_across_seqs=query.max_n_ctx_per_seq,\n max_n_ctx_k_across_seqs=key.max_n_ctx_per_seq,\n d_head=d_head,\n stride_ctx_q=query.raw_tensor.stride(0),\n stride_ctx_k=key.raw_tensor.stride(0),\n stride_out_seq=scores_out.stride(0),\n stride_out_q=scores_out.stride(1),\n stride_out_k=scores_out.stride(2),\n 
total_ctx_q_across_all_seqs=total_ctx_q_across_all_seqs,\n total_ctx_k_across_all_seqs=total_ctx_k_across_all_seqs,\n )\n return scores_out\n```\n###test function signature:\n```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_ragged_qk_dotprod_single_seq", + "reference": "\ndef test_ragged_qk_dotprod_single_seq(dtype):\n d_head = 2\n\n key = RaggedActivations.from_list(\n [\n make_seq(n_ctx=3, value=42, d_model=d_head, dtype=dtype),\n ]\n )\n query = RaggedActivations.from_list(\n [\n make_seq(n_ctx=4, value=55, d_model=d_head, dtype=dtype),\n ]\n )\n torch_scores = scores_via_qk_dotprod(query, key)\n print(f\"{torch_scores=}\")\n\n lut = RaggedQkPidLookupTable.from_query_and_key_tokens_per_seq(\n n_ctx_q_per_seq=query.n_ctx_per_seq, n_ctx_k_per_seq=key.n_ctx_per_seq\n )\n\n scores = ragged_qk_dotprod(query, key, lut)\n assert_eq(torch_scores, scores)", + "completions": [ + "```python\n\ndef test_ragged_qk_dotprod_single_seq(dtype):\n # query\n query_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n query = RaggedActivations(\n raw_tensor=query_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # key\n key_raw_tensor = torch.randn(\n 1, 1, 1, dtype=dtype, device=torch.device(\"cuda\")\n )\n key = RaggedActivations(\n raw_tensor=key_raw_tensor,\n n_seqs=1,\n max_n_ctx_per_seq=1,\n dtype=dtype,\n device=torch.device(\"cuda\"),\n )\n\n # lut\n pid_to_in_q_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_in_k_token_offset = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_q_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_k_block = torch.tensor(\n [0], dtype=torch.int32, device=torch.device(\"cuda\")\n )\n pid_to_out_seq_idx = torch.tensor(\n [0], dtype=torch.int32, 
device=torch.device(\"cuda\")\n )\n lut = RaggedQkPidLookupTable(\n pid_to_in_q_token_offset=pid_to_in_q_token_offset,\n pid_to_in_k_token_offset=pid_to_in_k_token_offset,\n pid_to_out_q_block=pid_to_out_q_block,\n pid_to_out_k_block=pid_to_out_k_block,\n pid_to_out_seq_idx=pid_to_out_seq_idx,\n n_pids_total=1,\n )\n\n # scores_out\n scores_out = torch.ones(\n (1, 1, 1), dtype=dtype, device=torch.device(\"cuda\")\n )\n\n # Stride along the d_head dimension must be 1\n assert query.raw_tensor.stride(1) == 1, f\"{query.raw_tensor.stride(1)}\"\n assert key.raw_tensor.stride(1) == 1, f\"{key.raw_tensor.stride(1)}\"\n\n # pid_to_seq_idx = [0, 0, 1, 2, 2]\n grid = (lut.n_pids_total,)\n _qk_dotprod_kernel[grid](\n q_ptr=query.raw_tensor,\n k_ptr=key.raw_tensor,\n scores_ptr=scores_out,\n # Lookup tables (sometimes referred to as a \"lut\")\n pid_to_in_q_token_offset_ptr=lut.pid_to_in_q_token_offset,\n pid_to_in_k_token_offset_ptr=lut.pid_to_in_k_token_offset,\n pid_to_out_q_block_ptr=lut.pid_to_out_q_block,\n pid_to_out_k_block_ptr=lut.pid_to_out_k_block,\n pid_to_out_seq_idx_ptr=lut.pid_to_" + ], + "line": 57, + "token": 310, + "line_diff": 4, + "token_diff": 3 + }, + { + "id": 89, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):\n # Ravel both arrays, behavior for the first array could be different\n ar1 = np.asarray(ar1).ravel()\n ar2 = np.asarray(ar2).ravel()\n\n # Ensure that iteration through object arrays yields size-1 arrays\n if ar2.dtype == object:\n ar2 = ar2.reshape(-1, 1)\n\n if kind not in {None, 'sort', 'table'}:\n raise ValueError(\n f\"Invalid kind: '{kind}'. 
Please use None, 'sort' or 'table'.\")\n\n # Can use the table method if all arrays are integers or boolean:\n is_int_arrays = all(ar.dtype.kind in (\"u\", \"i\", \"b\") for ar in (ar1, ar2))\n use_table_method = is_int_arrays and kind in {None, 'table'}\n\n if use_table_method:\n if ar2.size == 0:\n if invert:\n return np.ones_like(ar1, dtype=bool)\n else:\n return np.zeros_like(ar1, dtype=bool)\n\n # Convert booleans to uint8 so we can use the fast integer algorithm\n if ar1.dtype == bool:\n ar1 = ar1.astype(np.uint8)\n if ar2.dtype == bool:\n ar2 = ar2.astype(np.uint8)\n\n ar2_min = np.min(ar2)\n ar2_max = np.max(ar2)\n\n ar2_range = int(ar2_max) - int(ar2_min)\n\n # Constraints on whether we can actually use the table method:\n # 1. Assert memory usage is not too large\n below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)\n # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype\n range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max\n # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype\n if ar1.size > 0:\n ar1_min = np.min(ar1)\n ar1_max = np.max(ar1)\n\n # After masking, the range of ar1 is guaranteed to be\n # within the range of ar2:\n ar1_upper = min(int(ar1_max), int(ar2_max))\n ar1_lower = max(int(ar1_min), int(ar2_min))\n\n range_safe_from_overflow &= all((\n ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,\n ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min\n ))\n\n # Optimal performance is for approximately\n # log10(size) > (log10(range) - 2.27) / 0.927.\n # However, here we set the requirement that by default\n # the intermediate array can only be 6x\n # the combined memory allocation of the original\n # arrays. 
See discussion on \n # https://github.com/numpy/numpy/pull/12065.\n\n if (\n range_safe_from_overflow and \n (below_memory_constraint or kind == 'table')\n ):\n\n if invert:\n outgoing_array = np.ones_like(ar1, dtype=bool)\n else:\n outgoing_array = np.zeros_like(ar1, dtype=bool)\n\n # Make elements 1 where the integer exists in ar2\n if invert:\n isin_helper_ar = np.ones(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 0\n else:\n isin_helper_ar = np.zeros(ar2_range + 1, dtype=bool)\n isin_helper_ar[ar2 - ar2_min] = 1\n\n # Mask out elements we know won't work\n basic_mask = (ar1 <= ar2_max) & (ar1 >= ar2_min)\n outgoing_array[basic_mask] = isin_helper_ar[ar1[basic_mask] -\n ar2_min]\n\n return outgoing_array\n elif kind == 'table': # not range_safe_from_overflow\n raise RuntimeError(\n \"You have specified kind='table', \"\n \"but the range of values in `ar2` or `ar1` exceed the \"\n \"maximum integer of the datatype. \"\n \"Please set `kind` to None or 'sort'.\"\n )\n elif kind == 'table':\n raise ValueError(\n \"The 'table' method is only \"\n \"supported for boolean or integer arrays. \"\n \"Please select 'sort' or None for kind.\"\n )\n\n\n # Check if one of the arrays may contain arbitrary objects\n contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject\n\n # This code is run when\n # a) the first condition is true, making the code significantly faster\n # b) the second condition is true (i.e. 
`ar1` or `ar2` may contain\n # arbitrary objects), since then sorting is not guaranteed to work\n if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:\n if invert:\n mask = np.ones(len(ar1), dtype=bool)\n for a in ar2:\n mask &= (ar1 != a)\n else:\n mask = np.zeros(len(ar1), dtype=bool)\n for a in ar2:\n mask |= (ar1 == a)\n return mask\n\n # Otherwise use sorting\n if not assume_unique:\n ar1, rev_idx = np.unique(ar1, return_inverse=True)\n ar2 = np.unique(ar2)\n\n ar = np.concatenate((ar1, ar2))\n # We need this to be a stable sort, so always use 'mergesort'\n # here. The values from the first array should always come before\n # the values from the second array.\n order = ar.argsort(kind='mergesort')\n sar = ar[order]\n if invert:\n bool_ar = (sar[1:] != sar[:-1])\n else:\n bool_ar = (sar[1:] == sar[:-1])\n flag = np.concatenate((bool_ar, [invert]))\n ret = np.empty(ar.shape, dtype=bool)\n ret[order] = flag\n\n if assume_unique:\n return ret[:len(ar1)]\n else:\n return ret[rev_idx]\n```\n###test function signature:\n```python\n\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_in1d_ravel", + "reference": "\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. 
The test is to ensure consistentency.\n a = np.arange(6).reshape(2, 3)\n b = np.arange(3, 9).reshape(3, 2)\n long_b = np.arange(3, 63).reshape(30, 2)\n ec = np.array([False, False, False, True, True, True])\n\n assert_array_equal(in1d(a, b, assume_unique=True, kind=kind),\n ec)\n assert_array_equal(in1d(a, b, assume_unique=False,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=True,\n kind=kind),\n ec)\n assert_array_equal(in1d(a, long_b, assume_unique=False,\n kind=kind),\n ec)", + "completions": [ + "```python\n def test_in1d_ravel(self, kind):\n # Test that in1d ravels its input arrays. This is not documented\n # behavior however. The test is to ensure consistentency.\n ar1 = np.array([1, 2, 3, 4, 5])\n ar2 = np.array([1, 2, 3, 4, 5])\n ar1_ravel = ar1.ravel()\n ar2_ravel = ar2.ravel()\n self.assertTrue(np.array_equal(ar1_ravel, ar1))\n self.assertTrue(np.array_equal(ar2_ravel, ar2))\n self.assertTrue(np.array_equal(in1d(ar1, ar2, kind=kind),\n in1d(ar1_ravel, ar2_ravel, kind=kind)))\n```\n" + ], + "line": 133, + "token": 1112, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 90, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef header_to_metadata_and_cell(\n lines, header_prefix, header_suffix, ext=None, root_level_metadata_as_raw_cell=True\n):\n\n header = []\n jupyter = []\n in_jupyter = False\n in_html_div = False\n\n start = 0\n started = False\n ended = False\n metadata = {}\n i = -1\n\n comment = \"#\" if header_prefix == \"#'\" else header_prefix\n\n encoding_re = re.compile(\n rf\"^[ \\t\\f]*{re.escape(comment)}.*?coding[:=][ \\t]*([-_.a-zA-Z0-9]+)\"\n )\n\n for i, line in enumerate(lines):\n if i == 0 and line.startswith(\"#!\"):\n metadata.setdefault(\"jupytext\", {})[\"executable\"] = line[2:]\n start = i + 1\n continue\n if i == 0 or (i == 1 and not encoding_re.match(lines[0])):\n encoding = encoding_re.match(line)\n if encoding:\n if encoding.group(1) != \"utf-8\":\n raise ValueError(\"Encodings other than utf-8 are not supported\")\n metadata.setdefault(\"jupytext\", {})[\"encoding\"] = line\n start = i + 1\n continue\n if not line.startswith(header_prefix):\n break\n if not comment:\n if line.strip().startswith(\"\" in line:\n break\n if not started and not line.strip():\n continue\n\n line = uncomment_line(line, header_prefix, header_suffix)\n if _HEADER_RE.match(line):\n if not started:\n started = True\n continue\n ended = True\n if in_html_div:\n continue\n break\n\n # Stop if there is something else than a YAML header\n if not started and line.strip():\n break\n\n if _JUPYTER_RE.match(line):\n in_jupyter = True\n elif line and not _LEFTSPACE_RE.match(line):\n in_jupyter = False\n\n if in_jupyter:\n jupyter.append(line)\n else:\n header.append(line)\n\n if ended:\n if jupyter:\n extra_metadata = metadata\n metadata = yaml.safe_load(\"\\n\".join(jupyter))[\"jupyter\"]\n recursive_update(metadata, extra_metadata)\n\n lines_to_next_cell = 1\n if len(lines) > i + 
1:\n line = uncomment_line(lines[i + 1], header_prefix)\n if not _BLANK_RE.match(line):\n lines_to_next_cell = 0\n else:\n i = i + 1\n else:\n lines_to_next_cell = 0\n\n if header:\n if root_level_metadata_as_raw_cell:\n cell = new_raw_cell(\n source=\"\\n\".join([\"---\"] + header + [\"---\"]),\n metadata={}\n if lines_to_next_cell\n == pep8_lines_between_cells([\"---\"], lines[i + 1 :], ext)\n else {\"lines_to_next_cell\": lines_to_next_cell},\n )\n else:\n cell = None\n root_level_metadata = yaml.safe_load(\"\\n\".join(header))\n metadata.setdefault(\"jupytext\", {})[\n \"root_level_metadata\"\n ] = root_level_metadata\n else:\n cell = None\n\n return metadata, jupyter, cell, i + 1\n\n return metadata, False, None, start\n```\n###test function signature:\n```python\n\ndef test_header_to_metadata_and_cell_metadata():\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_header_to_metadata_and_cell_metadata", + "reference": "\ndef test_header_to_metadata_and_cell_metadata():\n text = \"\"\"---\ntitle: Sample header\njupyter:\n mainlanguage: python\n---\n\"\"\"\n lines = text.splitlines()\n metadata, _, cell, pos = header_to_metadata_and_cell(lines, \"\", \"\")\n\n assert metadata == {\"mainlanguage\": \"python\"}\n assert cell.cell_type == \"raw\"\n assert (\n cell.source\n == \"\"\"---\ntitle: Sample header\n---\"\"\"\n )\n assert cell.metadata == {\"lines_to_next_cell\": 0}\n assert pos == len(lines)", + "completions": [ + "```python\ndef test_header_to_metadata_and_cell_metadata():\n lines = [\n \"#' ---\",\n \"#' jupyter:\",\n \"#' jupytext:\",\n \"#' text_representation:\",\n \"#' extension: .py\",\n \"#' format_name: light\",\n \"#' format_version: '1.5'\",\n \"#' jupytext_version: 1.11.4\",\n \"#' kernelspec:\",\n \"#' display_name: Python 3\",\n \"#' language: python\",\n \"#' name: python3\",\n \"#' ---\",\n \"#'\",\n \"#' This is a comment\",\n \"#'\",\n \"#' And another one\",\n \"#'\",\n \"#' ```python\",\n \"#' import numpy as 
np\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' import matplotlib.pyplot as plt\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %matplotlib inline\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2, 3])\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.show()\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' %config InlineBackend.figure_format = 'retina'\",\n \"#' ```\",\n \"#'\",\n \"#' ```python\",\n \"#' plt.plot([1, 2" + ], + "line": 104, + "token": 642, + "line_diff": 6, + "token_diff": 6 + }, + { + 
"id": 91, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef import_optional_dependency(\n name: str,\n extra: str = \"\",\n errors: str = \"raise\",\n min_version: str | None = None,\n):\n\n assert errors in {\"warn\", \"raise\", \"ignore\"}\n\n package_name = INSTALL_MAPPING.get(name)\n install_name = package_name if package_name is not None else name\n\n msg = (\n f\"Missing optional dependency '{install_name}'. {extra} \"\n f\"Use pip or conda to install {install_name}.\"\n )\n try:\n module = importlib.import_module(name)\n except ImportError:\n if errors == \"raise\":\n raise ImportError(msg)\n return None\n\n # Handle submodules: if we have submodule, grab parent module from sys.modules\n parent = name.split(\".\")[0]\n if parent != name:\n install_name = parent\n module_to_get = sys.modules[install_name]\n else:\n module_to_get = module\n minimum_version = min_version if min_version is not None else VERSIONS.get(parent)\n if minimum_version:\n version = get_version(module_to_get)\n if version and Version(version) < Version(minimum_version):\n msg = (\n f\"Pandas requires version '{minimum_version}' or newer of '{parent}' \"\n f\"(version '{version}' currently installed).\"\n )\n if errors == \"warn\":\n warnings.warn(\n msg,\n UserWarning,\n stacklevel=find_stack_level(),\n )\n return None\n elif errors == \"raise\":\n raise ImportError(msg)\n\n return module\n```\n###test function signature:\n```python\n\ndef test_bad_version(monkeypatch):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_bad_version", + "reference": "\ndef test_bad_version(monkeypatch):\n name = \"fakemodule\"\n module = types.ModuleType(name)\n module.__version__ = \"0.9.0\"\n sys.modules[name] = module\n monkeypatch.setitem(VERSIONS, name, 
\"1.0.0\")\n\n match = \"Pandas requires .*1.0.0.* of .fakemodule.*'0.9.0'\"\n with pytest.raises(ImportError, match=match):\n import_optional_dependency(\"fakemodule\")\n\n # Test min_version parameter\n result = import_optional_dependency(\"fakemodule\", min_version=\"0.8\")\n assert result is module\n\n with tm.assert_produces_warning(UserWarning):\n result = import_optional_dependency(\"fakemodule\", errors=\"warn\")\n assert result is None\n\n module.__version__ = \"1.0.0\" # exact match is OK\n result = import_optional_dependency(\"fakemodule\")\n assert result is module", + "completions": [ + "```python\ndef test_bad_version(monkeypatch):\n monkeypatch.setattr(import_optional_dependency, \"get_version\", lambda x: \"0.1.0\")\n monkeypatch.setattr(import_optional_dependency, \"VERSIONS\", {\"pandas\": \"0.2.0\"})\n with pytest.raises(ImportError):\n import_optional_dependency(\"pandas\")\n```\n" + ], + "line": 54, + "token": 352, + "line_diff": 4, + "token_diff": 4 + }, + { + "id": 92, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef spectrogram(\n waveform: np.ndarray,\n window: np.ndarray,\n frame_length: int,\n hop_length: int,\n fft_length: Optional[int] = None,\n power: Optional[float] = 1.0,\n center: bool = True,\n pad_mode: str = \"reflect\",\n onesided: bool = True,\n preemphasis: Optional[float] = None,\n mel_filters: Optional[np.ndarray] = None,\n mel_floor: float = 1e-10,\n log_mel: Optional[str] = None,\n reference: float = 1.0,\n min_value: float = 1e-10,\n db_range: Optional[float] = None,\n remove_dc_offset: Optional[bool] = None,\n dtype: np.dtype = np.float32,\n) -> np.ndarray:\n window_length = len(window)\n\n if fft_length is None:\n fft_length = frame_length\n\n if frame_length > fft_length:\n raise ValueError(f\"frame_length ({frame_length}) may not be larger than fft_length ({fft_length})\")\n\n if window_length != frame_length:\n raise ValueError(f\"Length of the window ({window_length}) must equal frame_length ({frame_length})\")\n\n if hop_length <= 0:\n raise ValueError(\"hop_length must be greater than zero\")\n\n if waveform.ndim != 1:\n raise ValueError(f\"Input waveform must have only one dimension, shape is {waveform.shape}\")\n\n if np.iscomplexobj(waveform):\n raise ValueError(\"Complex-valued input waveforms are not currently supported\")\n\n if power is None and mel_filters is not None:\n raise ValueError(\n \"You have provided `mel_filters` but `power` is `None`. 
Mel spectrogram computation is not yet supported for complex-valued spectrogram.\"\n \"Specify `power` to fix this issue.\"\n )\n\n # center pad the waveform\n if center:\n padding = [(int(frame_length // 2), int(frame_length // 2))]\n waveform = np.pad(waveform, padding, mode=pad_mode)\n\n # promote to float64, since np.fft uses float64 internally\n waveform = waveform.astype(np.float64)\n window = window.astype(np.float64)\n\n # split waveform into frames of frame_length size\n num_frames = int(1 + np.floor((waveform.size - frame_length) / hop_length))\n\n num_frequency_bins = (fft_length // 2) + 1 if onesided else fft_length\n spectrogram = np.empty((num_frames, num_frequency_bins), dtype=np.complex64)\n\n # rfft is faster than fft\n fft_func = np.fft.rfft if onesided else np.fft.fft\n buffer = np.zeros(fft_length)\n\n timestep = 0\n for frame_idx in range(num_frames):\n buffer[:frame_length] = waveform[timestep : timestep + frame_length]\n\n if remove_dc_offset:\n buffer[:frame_length] = buffer[:frame_length] - buffer[:frame_length].mean()\n\n if preemphasis is not None:\n buffer[1:frame_length] -= preemphasis * buffer[: frame_length - 1]\n buffer[0] *= 1 - preemphasis\n\n buffer[:frame_length] *= window\n\n spectrogram[frame_idx] = fft_func(buffer)\n timestep += hop_length\n\n # note: ** is much faster than np.power\n if power is not None:\n spectrogram = np.abs(spectrogram, dtype=np.float64) ** power\n\n spectrogram = spectrogram.T\n\n if mel_filters is not None:\n spectrogram = np.maximum(mel_floor, np.dot(mel_filters.T, spectrogram))\n\n if power is not None and log_mel is not None:\n if log_mel == \"log\":\n spectrogram = np.log(spectrogram)\n elif log_mel == \"log10\":\n spectrogram = np.log10(spectrogram)\n elif log_mel == \"dB\":\n if power == 1.0:\n spectrogram = amplitude_to_db(spectrogram, reference, min_value, db_range)\n elif power == 2.0:\n spectrogram = power_to_db(spectrogram, reference, min_value, db_range)\n else:\n raise ValueError(f\"Cannot 
use log_mel option '{log_mel}' with power {power}\")\n else:\n raise ValueError(f\"Unknown log_mel option: {log_mel}\")\n\n spectrogram = np.asarray(spectrogram, dtype)\n\n return spectrogram\n```\n###test function signature:\n```python\n\n def test_mel_spectrogram(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_mel_spectrogram", + "reference": "\n def test_mel_spectrogram(self):\n waveform = self._load_datasamples(1)[0]\n\n mel_filters = mel_filter_bank(\n num_frequency_bins=513,\n num_mel_filters=13,\n min_frequency=100,\n max_frequency=4000,\n sampling_rate=16000,\n norm=None,\n mel_scale=\"htk\",\n )\n self.assertEqual(mel_filters.shape, (513, 13))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n )\n self.assertEqual(spec.shape, (513, 732))\n\n spec = spectrogram(\n waveform,\n window_function(800, \"hann\", frame_length=1024),\n frame_length=1024,\n hop_length=128,\n power=2.0,\n mel_filters=mel_filters,\n )\n self.assertEqual(spec.shape, (13, 732))\n\n # fmt: off\n expected = np.array([\n 1.08027889e+02, 1.48080673e+01, 7.70758213e+00, 9.57676639e-01,\n 8.81639061e-02, 5.26073833e-02, 1.52736155e-02, 9.95350117e-03,\n 7.95364356e-03, 1.01148004e-02, 4.29241020e-03, 9.90708797e-03,\n 9.44153646e-04\n ])\n # fmt: on\n self.assertTrue(np.allclose(spec[:, 300], expected))", + "completions": [ + "```python\n def test_mel_spectrogram(self):\n waveform = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])\n window = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])\n frame_length = 10\n hop_length = 1\n fft_length = 10\n power = 1.0\n center = True\n pad_mode = \"reflect\"\n onesided = True\n preemphasis = None\n mel_filters = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n mel_floor = 1e-10\n log_mel = \"log\"\n reference = 1.0\n min_value = 1e-10\n db_range = None\n remove_dc_offset = None\n dtype = np.float32\n\n spectrogram = spectrogram(\n waveform,\n 
window,\n frame_length,\n hop_length,\n fft_length,\n power,\n center,\n pad_mode,\n onesided,\n preemphasis,\n mel_filters,\n mel_floor,\n log_mel,\n reference,\n min_value,\n db_range,\n remove_dc_offset,\n dtype,\n )\n\n self.assertEqual(spectrogram.shape, (1, 10))\n self.assertEqual(spectrogram[0][0], 1)\n self.assertEqual(spectrogram[0][1], 2)\n self.assertEqual(spectrogram[0][2], 3)\n self.assertEqual(spectrogram[0][3], 4)\n self.assertEqual(spectrogram[0][4], 5)\n self.assertEqual(spectrogram[0][5], 6)\n self.assertEqual(spectrogram[0][6], 7)\n self.assertEqual(spectrogram[0][7], 8)\n self.assertEqual(spectrogram[0][8], 9)\n self.assertEqual(spectrogram[0][9], 10)\n```\n" + ], + "line": 95, + "token": 804, + "line_diff": 6, + "token_diff": 6 + }, + { + "id": 93, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef create(\n meta_schema,\n validators=(),\n version=None,\n type_checker=_types.draft202012_type_checker,\n format_checker=_format.draft202012_format_checker,\n id_of=_id_of,\n applicable_validators=methodcaller(\"items\"),\n):\n # preemptively don't shadow the `Validator.format_checker` local\n format_checker_arg = format_checker\n\n @attr.s\n class Validator:\n\n VALIDATORS = dict(validators)\n META_SCHEMA = dict(meta_schema)\n TYPE_CHECKER = type_checker\n FORMAT_CHECKER = format_checker_arg\n ID_OF = staticmethod(id_of)\n\n schema = attr.ib(repr=reprlib.repr)\n resolver = attr.ib(default=None, repr=False)\n format_checker = attr.ib(default=None)\n\n def __init_subclass__(cls):\n warnings.warn(\n (\n \"Subclassing validator classes is not intended to \"\n \"be part of their public API. 
A future version \"\n \"will make doing so an error, as the behavior of \"\n \"subclasses isn't guaranteed to stay the same \"\n \"between releases of jsonschema. Instead, prefer \"\n \"composition of validators, wrapping them in an object \"\n \"owned entirely by the downstream library.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n\n def __attrs_post_init__(self):\n if self.resolver is None:\n self.resolver = RefResolver.from_schema(\n self.schema,\n id_of=id_of,\n )\n\n @classmethod\n def check_schema(cls, schema, format_checker=_UNSET):\n Validator = validator_for(cls.META_SCHEMA, default=cls)\n if format_checker is _UNSET:\n format_checker = Validator.FORMAT_CHECKER\n validator = Validator(\n schema=cls.META_SCHEMA,\n format_checker=format_checker,\n )\n for error in validator.iter_errors(schema):\n raise exceptions.SchemaError.create_from(error)\n\n def evolve(self, **changes):\n # Essentially reproduces attr.evolve, but may involve instantiating\n # a different class than this one.\n cls = self.__class__\n\n schema = changes.setdefault(\"schema\", self.schema)\n NewValidator = validator_for(schema, default=cls)\n\n for field in attr.fields(cls):\n if not field.init:\n continue\n attr_name = field.name # To deal with private attributes.\n init_name = attr_name if attr_name[0] != \"_\" else attr_name[1:]\n if init_name not in changes:\n changes[init_name] = getattr(self, attr_name)\n\n return NewValidator(**changes)\n\n def iter_errors(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.iter_errors \"\n \"is deprecated and will be removed in a future \"\n \"release. Call validator.evolve(schema=new_schema).\"\n \"iter_errors(...) 
instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n else:\n _schema = self.schema\n\n if _schema is True:\n return\n elif _schema is False:\n yield exceptions.ValidationError(\n f\"False schema does not allow {instance!r}\",\n validator=None,\n validator_value=None,\n instance=instance,\n schema=_schema,\n )\n return\n\n scope = id_of(_schema)\n if scope:\n self.resolver.push_scope(scope)\n try:\n for k, v in applicable_validators(_schema):\n validator = self.VALIDATORS.get(k)\n if validator is None:\n continue\n\n errors = validator(self, v, instance, _schema) or ()\n for error in errors:\n # set details if not already set by the called fn\n error._set(\n validator=k,\n validator_value=v,\n instance=instance,\n schema=_schema,\n type_checker=self.TYPE_CHECKER,\n )\n if k not in {\"if\", \"$ref\"}:\n error.schema_path.appendleft(k)\n yield error\n finally:\n if scope:\n self.resolver.pop_scope()\n\n def descend(self, instance, schema, path=None, schema_path=None):\n for error in self.evolve(schema=schema).iter_errors(instance):\n if path is not None:\n error.path.appendleft(path)\n if schema_path is not None:\n error.schema_path.appendleft(schema_path)\n yield error\n\n def validate(self, *args, **kwargs):\n for error in self.iter_errors(*args, **kwargs):\n raise error\n\n def is_type(self, instance, type):\n try:\n return self.TYPE_CHECKER.is_type(instance, type)\n except exceptions.UndefinedTypeCheck:\n raise exceptions.UnknownType(type, instance, self.schema)\n\n def is_valid(self, instance, _schema=None):\n if _schema is not None:\n warnings.warn(\n (\n \"Passing a schema to Validator.is_valid is deprecated \"\n \"and will be removed in a future release. Call \"\n \"validator.evolve(schema=new_schema).is_valid(...) 
\"\n \"instead.\"\n ),\n DeprecationWarning,\n stacklevel=2,\n )\n self = self.evolve(schema=_schema)\n\n error = next(self.iter_errors(instance), None)\n return error is None\n\n if version is not None:\n safe = version.title().replace(\" \", \"\").replace(\"-\", \"\")\n Validator.__name__ = Validator.__qualname__ = f\"{safe}Validator\"\n Validator = validates(version)(Validator)\n\n return Validator\n```\n###test function signature:\n```python\n\n def test_validates_registers_meta_schema_id(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_validates_registers_meta_schema_id", + "reference": "\n def test_validates_registers_meta_schema_id(self):\n meta_schema_key = \"meta schema id\"\n my_meta_schema = {\"id\": meta_schema_key}\n\n validators.create(\n meta_schema=my_meta_schema,\n version=\"my version\",\n id_of=lambda s: s.get(\"id\", \"\"),\n )\n self.addCleanup(validators._META_SCHEMAS.pop, meta_schema_key)\n\n self.assertIn(meta_schema_key, validators._META_SCHEMAS)", + "completions": [ + "```python\n def test_validates_registers_meta_schema_id(self):\n validator = create(\n meta_schema={\n \"$id\": \"https://example.com/schema.json\",\n \"type\": \"object\",\n \"properties\": {\"foo\": {\"type\": \"string\"}},\n }\n )\n self.assertEqual(\n validator.resolver.store[\"https://example.com/schema.json\"],\n validator.schema,\n )\n```\n" + ], + "line": 160, + "token": 896, + "line_diff": 7, + "token_diff": 7 + }, + { + "id": 94, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef read_metrics(directory: str) -> pd.Series:\n\n dfs = []\n # Loop through all csv files in the directory\n csv_files = [\n filename for filename in os.listdir(directory) if filename.endswith(\".csv\")\n ]\n\n if len(csv_files) == 0:\n return pd.Series()\n\n for filename in csv_files:\n filepath = os.path.join(directory, filename)\n # Read the csv file into a pandas DataFrame\n df = pd.read_csv(\n filepath,\n header=None,\n names=[\n \"initial_time_iso\",\n \"lead_time_hours\",\n \"channel\",\n \"metric\",\n \"value\",\n ],\n )\n out = pd.DataFrame()\n out[\"initial_time\"] = pd.to_datetime(df[\"initial_time_iso\"])\n out[\"lead_time\"] = pd.to_timedelta(df[\"lead_time_hours\"], unit=\"h\")\n out[\"channel\"] = df[\"channel\"]\n out[\"metric\"] = df[\"metric\"]\n out[\"value\"] = df[\"value\"]\n dfs.append(out)\n\n df = pd.concat(dfs, axis=0)\n df.set_index([\"initial_time\", \"lead_time\", \"channel\", \"metric\"], inplace=True)\n return df[\"value\"]\n```\n###test function signature:\n```python\n\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_write_and_read_metric", + "reference": "\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n initial_time = datetime.datetime(2022, 1, 1, 0, 0, 0)\n lead_time = datetime.timedelta(hours=24)\n channel = \"t2m\"\n metric = \"rmse\"\n value = 25.6\n\n # Read the metric from the directory\n directory = tmpdir.mkdir(\"metrics\")\n filename = \"test_metric.csv\"\n filepath = os.path.join(directory, filename)\n with open(filepath, \"w\") as f:\n forecast_metrics_io.write_metric(\n f, initial_time, lead_time, channel, metric, value\n )\n metrics = 
forecast_metrics_io.read_metrics(directory)\n\n # Check that the metric was written correctly\n expected_index = pd.MultiIndex.from_tuples(\n [(initial_time, lead_time, channel, metric)],\n names=[\"initial_time\", \"lead_time\", \"channel\", \"metric\"],\n )\n expected_value = pd.Series([value], index=expected_index, name=\"value\")\n pd.testing.assert_series_equal(metrics, expected_value)", + "completions": [ + "```python\ndef test_write_and_read_metric(tmpdir):\n # Write a metric to the file object\n metric = pd.Series(\n data=[1, 2, 3],\n index=[\n pd.Timestamp(\"2020-01-01 00:00:00\"),\n pd.Timestamp(\"2020-01-01 01:00:00\"),\n pd.Timestamp(\"2020-01-01 02:00:00\"),\n ],\n name=\"metric\",\n )\n metric.to_csv(tmpdir.join(\"metric.csv\"), header=False)\n\n # Read the metric from the file object\n read_metric = read_metrics(tmpdir.strpath)\n\n # Check that the metric is the same\n assert_series_equal(metric, read_metric)\n```\n" + ], + "line": 43, + "token": 320, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 95, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n def evaluate(self, points):\n points = np.atleast_2d(points)\n\n dim, num_m = np.array(points).shape\n if dim != self.dim:\n raise ValueError(\"points have dimension {}, dataset has dimension \"\n \"{}\".format(dim, self.dim))\n\n result = np.zeros(num_m)\n\n if num_m >= self.num_dp:\n # there are more points than data, so loop over data\n for i in range(self.num_dp):\n diff = self.dataset[:, i, np.newaxis] - points\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result = result + np.exp(-energy)\n else:\n # loop over points\n for i in range(num_m):\n diff = self.dataset - points[:, i, np.newaxis]\n tdiff = np.dot(self.inv_cov, diff)\n energy = np.sum(diff * tdiff, axis=0) / 2.0\n result[i] = np.sum(np.exp(-energy), axis=0)\n\n result = result / self.norm_factor\n\n return result\n```\n###test function signature:\n```python\n\n def test_kde_bandwidth_method(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_kde_bandwidth_method", + "reference": "\n def test_kde_bandwidth_method(self):\n\n np.random.seed(8765678)\n n_basesample = 50\n xn = np.random.randn(n_basesample)\n\n # Default\n gkde = mlab.GaussianKDE(xn)\n # Supply a callable\n gkde2 = mlab.GaussianKDE(xn, 'scott')\n # Supply a scalar\n gkde3 = mlab.GaussianKDE(xn, bw_method=gkde.factor)\n\n xs = np.linspace(-7, 7, 51)\n kdepdf = gkde.evaluate(xs)\n kdepdf2 = gkde2.evaluate(xs)\n assert kdepdf.all() == kdepdf2.all()\n kdepdf3 = gkde3.evaluate(xs)\n assert kdepdf.all() == kdepdf3.all()", + "completions": [ + "```python\n def test_kde_bandwidth_method(self):\n from sklearn.neighbors import KernelDensity\n from sklearn.metrics import euclidean_distances\n from sklearn.datasets import make_blobs\n from sklearn.preprocessing import 
StandardScaler\n from sklearn.model_selection import train_test_split\n from sklearn.metrics import accuracy_score\n from sklearn.metrics import f1_score\n from sklearn.metrics import precision_score\n from sklearn.metrics import recall_score\n from sklearn.metrics import roc_auc_score\n from sklearn.metrics import confusion_matrix\n from sklearn.metrics import classification_report\n from sklearn.metrics import cohen_kappa_score\n from sklearn.metrics import mean_squared_error\n from sklearn.metrics import mean_absolute_error\n from sklearn.metrics import mean_squared_log_error\n from sklearn.metrics import median_absolute_error\n from sklearn.metrics import r2_score\n from sklearn.metrics import explained_variance_score\n from sklearn.metrics import silhouette_score\n from sklearn.metrics import calinski_harabasz_score\n from sklearn.metrics import davies_bouldin_score\n from sklearn.metrics import mutual_info_score\n from sklearn.metrics import normalized_mutual_info_score\n from sklearn.metrics import adjusted_mutual_info_score\n from sklearn.metrics import adjusted_rand_score\n from sklearn.metrics import completeness_score\n from sklearn.metrics import contingency_matrix\n from sklearn.metrics import homogeneity_completeness_v_measure\n from sklearn.metrics import homogeneity_score\n from sklearn.metrics import v_measure_score\n from sklearn.metrics import normalized_v_measure_score\n from sklearn.metrics import poisson\n from sklearn.metrics import kld\n from sklearn.metrics import kld_normalized\n from sklearn.metrics import kld_unnormalized\n from sklearn.metrics import kld_unnormalized_normalized\n from sklearn.metrics import kld_unnormalized_normalized_sym\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import 
kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from sklearn.metrics import kld_unnormalized_normalized_sym_\n from" + ], + "line": 33, + "token": 273, + "line_diff": 2, + "token_diff": 2 + }, + { + "id": 96, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef get_top_k_embeddings(\n query_embedding: List[float],\n embeddings: List[List[float]],\n similarity_fn: Optional[Callable[..., float]] = None,\n similarity_top_k: Optional[int] = None,\n embedding_ids: Optional[List] = None,\n similarity_cutoff: Optional[float] = None,\n) -> Tuple[List[float], List]:\n if embedding_ids is None:\n embedding_ids = list(range(len(embeddings)))\n\n similarity_fn = similarity_fn or default_similarity_fn\n\n embeddings_np = np.array(embeddings)\n query_embedding_np = np.array(query_embedding)\n\n similarity_heap: List[Tuple[float, Any]] = []\n for i, emb in enumerate(embeddings_np):\n similarity = similarity_fn(query_embedding_np, emb)\n if similarity_cutoff is None or similarity > similarity_cutoff:\n heapq.heappush(similarity_heap, (similarity, embedding_ids[i]))\n if similarity_top_k and len(similarity_heap) > similarity_top_k:\n heapq.heappop(similarity_heap)\n result_tups = sorted(similarity_heap, key=lambda x: x[0], reverse=True)\n\n result_similarities = [s for s, _ in result_tups]\n result_ids = [n for _, n in result_tups]\n\n return result_similarities, result_ids\n```\n###test function signature:\n```python\ndef test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_get_top_k_mmr_embeddings", + "reference": "def test_get_top_k_mmr_embeddings() -> None:\n # Results score should follow from the mmr algorithm\n query_embedding = [5.0, 0.0, 0.0]\n embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]]\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.8\n )\n\n assert np.isclose(0.8 * 4 / 5, result_similarities[0], atol=0.00001)\n assert 
np.isclose(\n 0.8 * 3 / 5 - (1 - 0.8) * (3 * 4 / 25 + 3 * 4 / 25),\n result_similarities[1],\n atol=0.00001,\n )\n assert np.isclose(\n 0.8 * -4 / 5 - (1 - 0.8) * (3 * -4 / 25 + 4 * 3 / 25),\n result_similarities[2],\n atol=0.00001,\n )\n assert result_ids == [0, 1, 2]\n\n # Tests that if the first embedding vector is close to the second,\n # it will return the third\n query_embedding = [1.0, 0.0, 1.0]\n embeddings = [[1.0, 0.0, 0.9], [1.0, 0.0, 0.8], [0.7, 0.0, 1.0]]\n\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=0.5\n )\n assert result_ids == [0, 2, 1]\n\n # Tests that embedding ids map properly to results\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, embedding_ids=[\"A\", \"B\", \"C\"], mmr_threshold=0.5\n )\n assert result_ids == [\"A\", \"C\", \"B\"]\n # Test that it will go back to the original order under a high threshold\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n assert result_ids == [0, 1, 2]\n\n # Test similarity_top_k works\n _, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1, similarity_top_k=2\n )\n assert result_ids == [0, 1]\n\n # Test the results for get_top_k_embeddings and get_top_k_mmr_embeddings are the\n # same for threshold = 1\n query_embedding = [10, 23, 90, 78]\n embeddings = [[1, 23, 89, 68], [1, 74, 144, 23], [0.23, 0.0, 1.0, 9]]\n result_similarities_no_mmr, result_ids_no_mmr = get_top_k_embeddings(\n query_embedding, embeddings\n )\n result_similarities, result_ids = get_top_k_mmr_embeddings(\n query_embedding, embeddings, mmr_threshold=1\n )\n\n for result_no_mmr, result_with_mmr in zip(\n result_similarities_no_mmr, result_similarities\n ):\n assert np.isclose(result_no_mmr, result_with_mmr, atol=0.00001)", + "completions": [ + "```python\ndef test_get_top_k_mmr_embeddings() -> None:\n query_embedding = [0.1, 0.2, 0.3, 0.4, 0.5]\n embeddings = [\n [0.1, 0.2, 0.3, 0.4, 
0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n [0.1, 0.2, 0.3, 0.4, 0.5],\n ]\n result_similarities, result_ids = get_top_k_embeddings(\n query_embedding, embeddings, similarity_top_k=3\n )\n assert result_similarities == [0.5, 0.4, 0.3]\n assert result_ids == [0, 1, 2]\n```\n" + ], + "line": 35, + "token": 309, + "line_diff": 2, + "token_diff": 3 + }, + { + "id": 97, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef frame_accuracy(\n logits: np.ndarray,\n label: np.ndarray,\n background_logit_threshold: float = 0.0,\n) -> float:\n top1_idx = np.argmax(logits, axis=-1)\n background_label = np.sum(label, axis=-1) == 0\n pred_background = (\n np.max(logits, axis=-1) <= background_logit_threshold\n ).astype(np.int32)\n\n # Extracts the label at the highest logit index for each input.\n top1_correct = np.take_along_axis(label, top1_idx[..., None], axis=-1)\n top1_correct = np.squeeze(top1_correct)\n foreground_correct = ~pred_background.astype(bool) * top1_correct\n\n # Count correctly classified background samples.\n background_correct = pred_background * background_label\n correct = foreground_correct + background_correct\n return np.sum(correct) / len(logits)\n```\n###test function signature:\n```python\n\n def test_frame_accuracy_all_background(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_frame_accuracy_all_background", + "reference": "\n def test_frame_accuracy_all_background(self):\n logits = np.array(\n [\n [1.5, 0.0, -1.0],\n [0.1, 0.7, 0.2],\n [0.1, 0.7, 0.2],\n [0.1, -0.7, 0.2],\n [-0.1, -0.7, -0.2],\n ],\n np.float32,\n )\n label = np.array(\n [\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n [0, 0, 0],\n ],\n np.int32,\n )\n actual = 
metrics.frame_accuracy(logits, label)\n self.assertAlmostEqual(actual, 0.2)", + "completions": [ + "```python\ndef test_frame_accuracy_all_background(self):\n logits = np.array([[0.1, 0.2, 0.3, 0.4]])\n label = np.array([[0, 0, 0, 0]])\n self.assertAlmostEqual(frame_accuracy(logits, label), 1.0)\n```\n" + ], + "line": 28, + "token": 206, + "line_diff": 1, + "token_diff": 1 + }, + { + "id": 98, + "prompt": "\nYou are an expert Python programmer. You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\n\ndef load_image_from_string(\n reference: str,\n max_height: Optional[int] = None,\n max_width: Optional[int] = None,\n) -> Tuple[str, Optional[float]]:\n if uri_is_http_link(uri=reference):\n return load_image_from_url(\n url=reference, max_height=max_height, max_width=max_width\n )\n if os.path.exists(reference):\n if max_height is None or max_width is None:\n with open(reference, \"rb\") as f:\n img_bytes = f.read()\n img_base64_str = encode_base_64(payload=img_bytes)\n return img_base64_str, None\n local_image = cv2.imread(reference)\n if local_image is None:\n raise EncodingError(f\"Could not load image from {reference}\")\n local_image, scaling_factor = resize_opencv_image(\n image=local_image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=local_image), scaling_factor\n if max_height is not None and max_width is not None:\n image_bytes = base64.b64decode(reference)\n image = bytes_to_opencv_image(payload=image_bytes)\n image, scaling_factor = resize_opencv_image(\n image=image,\n max_height=max_height,\n max_width=max_width,\n )\n return numpy_array_to_base64_jpeg(image=image), scaling_factor\n return reference, None\n```\n###test function signature:\n```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, 
np.ndarray]\n) -> None:\n # given\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_load_image_from_string_when_local_image_to_be_loaded", + "reference": "\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n file_path, _ = example_local_image\n\n # when\n serialised_image, scaling_factor = load_image_from_string(\n reference=file_path,\n max_width=64,\n max_height=128,\n )\n recovered_image = base64.b64decode(serialised_image)\n bytes_array = np.frombuffer(recovered_image, dtype=np.uint8)\n decoding_result = cv2.imdecode(bytes_array, cv2.IMREAD_UNCHANGED)\n\n # then\n assert abs(scaling_factor - 0.5) < 1e-5\n assert decoding_result.shape == (64, 64, 3)\n assert (decoding_result == 0).all()", + "completions": [ + "```python\n\ndef test_load_image_from_string_when_local_image_to_be_loaded(\n example_local_image: Tuple[str, np.ndarray]\n) -> None:\n # given\n reference, image = example_local_image\n # when\n result = load_image_from_string(\n reference=reference,\n max_height=image.shape[0],\n max_width=image.shape[1],\n )\n # then\n assert result[0] == reference\n assert result[1] is None\n```\n" + ], + "line": 47, + "token": 280, + "line_diff": 3, + "token_diff": 3 + }, + { + "id": 99, + "prompt": "\nYou are an expert Python programmer. 
You will be given some information and will generate a correct test case for the given function.\nYour response will only be python test function and without anything else.\n\n### focal function:\n```python\ndef hertz_to_mel(freq: Union[float, np.ndarray], mel_scale: str = \"htk\") -> Union[float, np.ndarray]:\n\n if mel_scale not in [\"slaney\", \"htk\", \"kaldi\"]:\n raise ValueError('mel_scale should be one of \"htk\", \"slaney\" or \"kaldi\".')\n\n if mel_scale == \"htk\":\n return 2595.0 * np.log10(1.0 + (freq / 700.0))\n elif mel_scale == \"kaldi\":\n return 1127.0 * np.log(1.0 + (freq / 700.0))\n\n min_log_hertz = 1000.0\n min_log_mel = 15.0\n logstep = 27.0 / np.log(6.4)\n mels = 3.0 * freq / 200.0\n\n if isinstance(freq, np.ndarray):\n log_region = freq >= min_log_hertz\n mels[log_region] = min_log_mel + np.log(freq[log_region] / min_log_hertz) * logstep\n elif freq >= min_log_hertz:\n mels = min_log_mel + np.log(freq / min_log_hertz) * logstep\n\n return mels\n```\n###test function signature:\n```python\n\n def test_hertz_to_mel(self):\n```\n### Answer: (use the provided format with backticks)\n", + "name": "test_hertz_to_mel", + "reference": "\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(0.0), 0.0)\n self.assertAlmostEqual(hertz_to_mel(100), 150.48910241)\n\n inputs = np.array([100, 200])\n expected = np.array([150.48910241, 283.22989816])\n self.assertTrue(np.allclose(hertz_to_mel(inputs), expected))\n\n self.assertEqual(hertz_to_mel(0.0, \"slaney\"), 0.0)\n self.assertEqual(hertz_to_mel(100, \"slaney\"), 1.5)\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([0.9, 1.5, 3.0, 15.0, 15.01453781, 25.08188016])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"slaney\"), expected))\n\n inputs = np.array([60, 100, 200, 1000, 1001, 2000])\n expected = np.array([92.6824, 150.4899, 283.2313, 999.9907, 1000.6534, 1521.3674])\n self.assertTrue(np.allclose(hertz_to_mel(inputs, \"kaldi\"), expected))\n\n with 
pytest.raises(ValueError):\n hertz_to_mel(100, mel_scale=None)", + "completions": [ + "```python\n def test_hertz_to_mel(self):\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"slaney\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"htk\"), 0)\n self.assertEqual(hertz_to_mel(1000, \"kaldi\"), 0)\n 
self.assertEqual(hertz_to_mel(" + ], + "line": 27, + "token": 262, + "line_diff": 1, + "token_diff": 2 + } +] \ No newline at end of file diff --git a/dataset/Test Generation/ComplexCodeEval-Python/8/QS/line_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/8/QS/line_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..baf4a2212aba4d50435cfaff8a769bb9ebaa7e09 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/8/QS/line_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(15~23),subset_1(24~29),subset_2(31~40),subset_3(40~49),subset_4(50~58),subset_5(58~76),subset_6(79~133),subset_7(133~749) +StarCoder2-15b,28.20,28.73,22.45,22.71,24.44,18.34,25.32,24.47 +CodeLlama-7b,21.38,25.55,22.38,31.90,29.45,24.23,30.11,27.22 +CodeLlama-13b,30.03,28.32,28.37,24.80,27.65,25.20,22.43,26.98 +CodeLlama-34b,30.29,30.68,26.60,21.67,25.88,25.78,27.94,27.66 +DeepSeek-Coder-1.3b,22.21,25.49,24.23,28.70,27.72,22.97,26.75,27.07 +DeepSeek-Coder-6.7b,25.00,24.71,23.57,27.31,24.29,24.95,22.78,23.69 +DeepSeek-Coder-33b,30.28,29.29,29.88,25.43,32.16,26.34,27.36,31.42 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/8/QS/token_counts_QS.csv b/dataset/Test Generation/ComplexCodeEval-Python/8/QS/token_counts_QS.csv new file mode 100644 index 0000000000000000000000000000000000000000..1d8d96f6e499f090cc64f4a9ded4f19b0ca04d84 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/8/QS/token_counts_QS.csv @@ -0,0 +1,8 @@ +Models, subset_0(122~162),subset_1(162~226),subset_2(237~279),subset_3(280~347),subset_4(351~425),subset_5(430~586),subset_6(610~886),subset_7(896~7038) +StarCoder2-15b,27.70,24.46,27.81,21.71,24.47,21.86,22.09,24.47 +CodeLlama-7b,22.14,26.22,26.55,29.38,28.63,25.69,26.15,27.22 +CodeLlama-13b,31.23,30.24,21.43,25.45,27.93,27.21,22.43,26.98 +CodeLlama-34b,27.71,28.97,31.50,23.61,23.51,25.10,28.47,27.66 +DeepSeek-Coder-1.3b,21.46,27.43,24.93,28.16,25.51,25.82,24.65,27.07 
+DeepSeek-Coder-6.7b,26.18,23.26,25.55,25.35,26.38,23.48,22.43,23.69 +DeepSeek-Coder-33b,29.25,28.76,31.47,26.92,27.73,29.55,27.07,31.42 diff --git a/dataset/Test Generation/ComplexCodeEval-Python/8/QS/tongji.py b/dataset/Test Generation/ComplexCodeEval-Python/8/QS/tongji.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5f6a831a70100f6970eb8fc556b1717f72faa8 --- /dev/null +++ b/dataset/Test Generation/ComplexCodeEval-Python/8/QS/tongji.py @@ -0,0 +1,50 @@ +import json +from collections import defaultdict +import os + +def analyze_json_file(file_path): + # 读取JSON文件 + with open(file_path, 'r',encoding="utf-8") as f: + data = json.load(f) + + # 初始化统计字典 + line_diff_stats = defaultdict(list) + token_diff_stats = defaultdict(list) + + # 收集数据 + for entry in data: + line_diff = entry['line_diff'] + token_diff = entry['token_diff'] + line = entry['line'] + token = entry['token'] + + line_diff_stats[line_diff].append(line) + token_diff_stats[token_diff].append(token) + + # 处理line_diff统计结果 + print("Models: ", end="") + line_diff_keys = sorted(line_diff_stats.keys()) + line_subsets = [] + for diff in line_diff_keys: + lines = line_diff_stats[diff] + min_line = min(lines) + max_line = max(lines) + line_subsets.append(f"subset_{diff}({min_line}~{max_line})") + print(",".join(line_subsets)) + + # 处理token_diff统计结果 + print("Models: ", end="") + token_diff_keys = sorted(token_diff_stats.keys()) + token_subsets = [] + for diff in token_diff_keys: + tokens = token_diff_stats[diff] + min_token = min(tokens) + max_token = max(tokens) + token_subsets.append(f"subset_{diff}({min_token}~{max_token})") + print(",".join(token_subsets)) + +# 使用示例 +if __name__ == "__main__": + # 假设JSON文件名为data.json + file_path = "EI.json" if os.path.exists("EI.json") else "QS.json" + analyze_json_file(file_path) \ No newline at end of file